From 695df63ec71634952e8b9ec6c0ccd83d6e0550a8 Mon Sep 17 00:00:00 2001 From: lishuode Date: Thu, 22 Jun 2017 07:28:15 +0000 Subject: [PATCH 1/2] Add support for SHA256 auth plugin --- pymysql/auth/__init__.py | 0 pymysql/auth/sha256_password_plugin.py | 58 ++++++++++++++++++++++++++ pymysql/connections.py | 57 ++++++++++++++++++------- pymysql/tests/test_connection.py | 5 ++- 4 files changed, 102 insertions(+), 18 deletions(-) create mode 100644 pymysql/auth/__init__.py create mode 100644 pymysql/auth/sha256_password_plugin.py diff --git a/pymysql/auth/__init__.py b/pymysql/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py new file mode 100644 index 00000000..539b9249 --- /dev/null +++ b/pymysql/auth/sha256_password_plugin.py @@ -0,0 +1,58 @@ +from ..constants import CLIENT +from ..err import OperationalError + +# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm +# which is needed when use sha256_password_plugin with no SSL +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization, hashes + from cryptography.hazmat.primitives.asymmetric import padding + HAVE_CRYPTOGRAPHY = True +except ImportError: + HAVE_CRYPTOGRAPHY = False + + +def _xor_password(password, salt): + password_bytes = bytearray(password, 'ascii') + salt_len = len(salt) + for i in range(len(password_bytes)): + password_bytes[i] ^= ord(salt[i % salt_len]) + return password_bytes + + +def _sha256_rsa_crypt(password, salt, public_key): + if not HAVE_CRYPTOGRAPHY: + raise OperationalError("cryptography module not found for sha256_password_plugin") + message = _xor_password(password + b'\0', salt) + rsa_key = serialization.load_pem_public_key(public_key, default_backend()) + return rsa_key.encrypt( + message.decode('latin1').encode('latin1'), padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA1()), + algorithm=hashes.SHA1(), + label=None)) + + +class SHA256PasswordPlugin(object): + def __init__(self, con): + self.con = con + + def authenticate(self, pkt): + if self.con.ssl and self.con.server_capabilities & CLIENT.SSL: + data = self.con.password.encode('latin1') + b'\0' + else: + if pkt.is_auth_switch_request(): + self.con.salt = pkt.read_all() + if self.con.server_public_key == '': + self.con.write_packet(b'\1') + pkt = self.con._read_packet() + if pkt.is_extra_auth_data() and self.con.server_public_key == '': + pkt.read_uint8() + self.con.server_public_key = pkt.read_all() + data = _sha256_rsa_crypt( + self.con.password, + self.con.salt, + self.con.server_public_key) + self.con.write_packet(data) + pkt = self.con._read_packet() + pkt.check_error() + return pkt diff --git a/pymysql/connections.py b/pymysql/connections.py index ac16c993..f9e91a30 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -16,6 +16,7 @@ import traceback import warnings +from .auth import sha256_password_plugin as _auth from .charset import MBLENGTH, charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS from .converters import escape_item, escape_string, through, conversions as _conv @@ -39,7 +40,6 @@ # KeyError occurs when there's no entry in OS database for a current user. DEFAULT_USER = None - DEBUG = False _py_version = sys.version_info[:2] @@ -144,7 +144,6 @@ def _scramble(password, message): result = s.digest() return _my_crypt(result, stage1) - def _my_crypt(message1, message2): length = len(message1) result = b'' @@ -374,6 +373,10 @@ def is_auth_switch_request(self): # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest return self._data[0:1] == b'\xfe' + def is_extra_auth_data(self): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + return self._data[0:1] == b'\x01' + def is_resultset_packet(self): field_count = ord(self._data[0:1]) return 1 <= field_count <= 250 @@ -566,6 +569,7 @@ class Connection(object): The class needs an authenticate method taking an authentication packet as an argument. For the dialog plugin, a prompt(echo, prompt) method can be used (if no authenticate method) for returning a string from the user. (experimental) + :param server_public_key: SHA256 authenticaiton plugin public key value. (default: '') :param db: Alias for database. (for compatibility to MySQLdb) :param passwd: Alias for password. (for compatibility to MySQLdb) """ @@ -584,7 +588,7 @@ def __init__(self, host=None, user=None, password="", autocommit=False, db=None, passwd=None, local_infile=False, max_allowed_packet=16*1024*1024, defer_connect=False, auth_plugin_map={}, read_timeout=None, write_timeout=None, - bind_address=None): + server_public_key='', bind_address=None): if no_delay is not None: warnings.warn("no_delay option is deprecated", DeprecationWarning) @@ -699,6 +703,9 @@ def _config(key, arg): self.init_command = init_command self.max_allowed_packet = max_allowed_packet self._auth_plugin_map = auth_plugin_map + if b"sha256_password" not in self._auth_plugin_map: + self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin + self.server_public_key = server_public_key if defer_connect: self._sock = None else: @@ -805,7 +812,7 @@ def select_db(self, db): def escape(self, obj, mapping=None): """Escape whatever value you pass to it. - + Non-standard, for internal use; do not use this in your applications. """ if isinstance(obj, str_type): @@ -814,7 +821,7 @@ def escape(self, obj, mapping=None): def literal(self, obj): """Alias for escape() - + Non-standard, for internal use; do not use this in your applications. """ return self.escape(obj, self.encoders) @@ -1128,6 +1135,14 @@ def _request_authentication(self): authresp = b'' if self._auth_plugin_name in ('', 'mysql_native_password'): authresp = _scramble(self.password.encode('latin1'), self.salt) + elif self._auth_plugin_name == 'sha256_password': + if self.ssl and self.server_capabilities & CLIENT.SSL: + authresp = self.password.encode('latin1') + b'\0' + else: + if self.password is not None: + authresp = b'\1' + else: + authresp = b'\0' if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += lenenc_int(len(authresp)) + authresp @@ -1163,24 +1178,20 @@ def _request_authentication(self): data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0' self.write_packet(data) auth_packet = self._read_packet() + elif auth_packet.is_extra_auth_data(): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + handler = self._get_auth_plugin_handler(self._auth_plugin_name) + handler.authenticate(auth_packet) def _process_auth(self, plugin_name, auth_packet): - plugin_class = self._auth_plugin_map.get(plugin_name) - if not plugin_class: - plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) - if plugin_class: + handler = self._get_auth_plugin_handler(plugin_name) + if handler != None: try: - handler = plugin_class(self) return handler.authenticate(auth_packet) except AttributeError: if plugin_name != b'dialog': raise err.OperationalError(2059, "Authentication plugin '%s'" \ " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class)) - except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) - else: - handler = None if plugin_name == b"mysql_native_password": # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) @@ -1225,6 +1236,20 @@ def _process_auth(self, plugin_name, auth_packet): pkt = self._read_packet() pkt.check_error() return pkt + + def _get_auth_plugin_handler(self, plugin_name): + plugin_class = self._auth_plugin_map.get(plugin_name) + if not plugin_class: + plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) + if plugin_class: + try: + handler = plugin_class(self) + except TypeError: + raise err.OperationalError(2059, "Authentication plugin '%s'" \ + " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) + else: + handler = None + return handler # _mysql support def thread_id(self): @@ -1490,7 +1515,7 @@ def _get_descriptions(self): # This behavior is different from TEXT / BLOB. # We should decode result by connection encoding regardless charsetnr. # See https://github.com/PyMySQL/PyMySQL/issues/488 - encoding = conn_encoding # SELECT CAST(... AS JSON) + encoding = conn_encoding # SELECT CAST(... AS JSON) elif field_type in TEXT_TYPES: if field.charsetnr == 63: # binary # TEXTs with charset=binary means BINARY types. diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 518b6fe7..f1cda77a 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -355,11 +355,12 @@ def testAuthSHA256(self): else: c.execute('SET old_passwords = 2') c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')") + c.execute("FLUSH PRIVILEGES") db = self.db.copy() db['password'] = "Sh@256Pa33" - # not implemented yet so thows error + # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_256', **db) + pymysql.connect(user='pymysql_sha256', **db) class TestConnection(base.PyMySQLTestCase): From c93b3222ab3af4fccd5d15b2ae5b9de91fe1c499 Mon Sep 17 00:00:00 2001 From: lishuode Date: Thu, 10 Aug 2017 02:25:29 +0000 Subject: [PATCH 2/2] travis: MySQL 5.6.35 => 5.6.37, MySQL 5.7.17 => 5.7.19 --- .travis.yml | 4 ++-- .travis/initializedb.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 65888224..2b63f9f2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,7 +24,7 @@ matrix: python: "2.7" - env: - - DB=5.6.35 + - DB=5.6.37 addons: apt: packages: @@ -32,7 +32,7 @@ matrix: python: "3.3" - env: - - DB=5.7.17 + - DB=5.7.19 addons: apt: packages: diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh index df0e900b..32cf9f9f 100755 --- a/.travis/initializedb.sh +++ b/.travis/initializedb.sh @@ -9,7 +9,7 @@ if [ ! -z "${DB}" ]; then # disable existing database server in case of accidential connection mysql -u root -e 'drop user travis@localhost; drop user root@localhost; drop user travis; create user super@localhost; grant all on *.* to super@localhost with grant option' mysql -u super -e 'drop user root' - F=mysql-${DB}-linux-glibc2.5-x86_64 + F=mysql-${DB}-linux-glibc2.12-x86_64 mkdir -p ${HOME}/mysql P=${HOME}/mysql/${F} if [ ! -d "${P}" ]; then