diff --git a/pymysql/auth/__init__.py b/pymysql/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymysql/auth/sha256_password_plugin.py b/pymysql/auth/sha256_password_plugin.py new file mode 100644 index 00000000..539b9249 --- /dev/null +++ b/pymysql/auth/sha256_password_plugin.py @@ -0,0 +1,58 @@ +from ..constants import CLIENT +from ..err import OperationalError + +# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm +# which is needed when use sha256_password_plugin with no SSL +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization, hashes + from cryptography.hazmat.primitives.asymmetric import padding + HAVE_CRYPTOGRAPHY = True +except ImportError: + HAVE_CRYPTOGRAPHY = False + + +def _xor_password(password, salt): + password_bytes = bytearray(password, 'ascii') + salt_len = len(salt) + for i in range(len(password_bytes)): + password_bytes[i] ^= ord(salt[i % salt_len]) + return password_bytes + + +def _sha256_rsa_crypt(password, salt, public_key): + if not HAVE_CRYPTOGRAPHY: + raise OperationalError("cryptography module not found for sha256_password_plugin") + message = _xor_password(password + b'\0', salt) + rsa_key = serialization.load_pem_public_key(public_key, default_backend()) + return rsa_key.encrypt( + message.decode('latin1').encode('latin1'), padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA1()), + algorithm=hashes.SHA1(), + label=None)) + + +class SHA256PasswordPlugin(object): + def __init__(self, con): + self.con = con + + def authenticate(self, pkt): + if self.con.ssl and self.con.server_capabilities & CLIENT.SSL: + data = self.con.password.encode('latin1') + b'\0' + else: + if pkt.is_auth_switch_request(): + self.con.salt = pkt.read_all() + if self.con.server_public_key == '': + self.con.write_packet(b'\1') + pkt = self.con._read_packet() + if pkt.is_extra_auth_data() and self.con.server_public_key == '': + pkt.read_uint8() + self.con.server_public_key = pkt.read_all() + data = _sha256_rsa_crypt( + self.con.password, + self.con.salt, + self.con.server_public_key) + self.con.write_packet(data) + pkt = self.con._read_packet() + pkt.check_error() + return pkt diff --git a/pymysql/connections.py b/pymysql/connections.py index 967d0c59..0f56bf59 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -16,6 +16,7 @@ import traceback import warnings +from .auth import sha256_password_plugin as _auth from .charset import MBLENGTH, charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS from . import converters @@ -39,7 +40,6 @@ # KeyError occurs when there's no entry in OS database for a current user. DEFAULT_USER = None - DEBUG = False _py_version = sys.version_info[:2] @@ -135,7 +135,6 @@ def _scramble(password, message): result = s.digest() return _my_crypt(result, stage1) - def _my_crypt(message1, message2): length = len(message1) result = b'' @@ -365,6 +364,10 @@ def is_auth_switch_request(self): # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest return self._data[0:1] == b'\xfe' + def is_extra_auth_data(self): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + return self._data[0:1] == b'\x01' + def is_resultset_packet(self): field_count = ord(self._data[0:1]) return 1 <= field_count <= 250 @@ -559,6 +562,7 @@ class Connection(object): The class needs an authenticate method taking an authentication packet as an argument. For the dialog plugin, a prompt(echo, prompt) method can be used (if no authenticate method) for returning a string from the user. (experimental) + :param server_public_key: SHA256 authenticaiton plugin public key value. (default: '') :param db: Alias for database. (for compatibility to MySQLdb) :param passwd: Alias for password. (for compatibility to MySQLdb) :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False) @@ -581,7 +585,7 @@ def __init__(self, host=None, user=None, password="", autocommit=False, db=None, passwd=None, local_infile=False, max_allowed_packet=16*1024*1024, defer_connect=False, auth_plugin_map={}, read_timeout=None, write_timeout=None, - bind_address=None, binary_prefix=False): + bind_address=None, binary_prefix=False, server_public_key=''): if no_delay is not None: warnings.warn("no_delay option is deprecated", DeprecationWarning) @@ -697,6 +701,9 @@ def _config(key, arg): self.init_command = init_command self.max_allowed_packet = max_allowed_packet self._auth_plugin_map = auth_plugin_map + if b"sha256_password" not in self._auth_plugin_map: + self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin + self.server_public_key = server_public_key self._binary_prefix = binary_prefix if defer_connect: self._sock = None @@ -826,7 +833,7 @@ def select_db(self, db): def escape(self, obj, mapping=None): """Escape whatever value you pass to it. - + Non-standard, for internal use; do not use this in your applications. """ if isinstance(obj, str_type): @@ -840,7 +847,7 @@ def escape(self, obj, mapping=None): def literal(self, obj): """Alias for escape() - + Non-standard, for internal use; do not use this in your applications. """ return self.escape(obj, self.encoders) @@ -1180,6 +1187,14 @@ def _request_authentication(self): authresp = b'' if self._auth_plugin_name in ('', 'mysql_native_password'): authresp = _scramble(self.password.encode('latin1'), self.salt) + elif self._auth_plugin_name == 'sha256_password': + if self.ssl and self.server_capabilities & CLIENT.SSL: + authresp = self.password.encode('latin1') + b'\0' + else: + if self.password is not None: + authresp = b'\1' + else: + authresp = b'\0' if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += lenenc_int(len(authresp)) + authresp @@ -1215,24 +1230,20 @@ def _request_authentication(self): data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0' self.write_packet(data) auth_packet = self._read_packet() + elif auth_packet.is_extra_auth_data(): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + handler = self._get_auth_plugin_handler(self._auth_plugin_name) + handler.authenticate(auth_packet) def _process_auth(self, plugin_name, auth_packet): - plugin_class = self._auth_plugin_map.get(plugin_name) - if not plugin_class: - plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) - if plugin_class: + handler = self._get_auth_plugin_handler(plugin_name) + if handler != None: try: - handler = plugin_class(self) return handler.authenticate(auth_packet) except AttributeError: if plugin_name != b'dialog': raise err.OperationalError(2059, "Authentication plugin '%s'" \ " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class)) - except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) - else: - handler = None if plugin_name == b"mysql_native_password": # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) @@ -1277,6 +1288,20 @@ def _process_auth(self, plugin_name, auth_packet): pkt = self._read_packet() pkt.check_error() return pkt + + def _get_auth_plugin_handler(self, plugin_name): + plugin_class = self._auth_plugin_map.get(plugin_name) + if not plugin_class: + plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) + if plugin_class: + try: + handler = plugin_class(self) + except TypeError: + raise err.OperationalError(2059, "Authentication plugin '%s'" \ + " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) + else: + handler = None + return handler # _mysql support def thread_id(self): @@ -1551,7 +1576,7 @@ def _get_descriptions(self): # This behavior is different from TEXT / BLOB. # We should decode result by connection encoding regardless charsetnr. # See https://github.com/PyMySQL/PyMySQL/issues/488 - encoding = conn_encoding # SELECT CAST(... AS JSON) + encoding = conn_encoding # SELECT CAST(... AS JSON) elif field_type in TEXT_TYPES: if field.charsetnr == 63: # binary # TEXTs with charset=binary means BINARY types. diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index c626a0d3..28091be2 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -360,11 +360,12 @@ def testAuthSHA256(self): else: c.execute('SET old_passwords = 2') c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')") + c.execute("FLUSH PRIVILEGES") db = self.db.copy() db['password'] = "Sh@256Pa33" - # not implemented yet so thows error + # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_256', **db) + pymysql.connect(user='pymysql_sha256', **db) class TestConnection(base.PyMySQLTestCase):