Skip to content

Add support for SHA256 auth plugin #583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added pymysql/auth/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions pymysql/auth/sha256_password_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from ..constants import CLIENT
from ..err import OperationalError

# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm
# which is needed when use sha256_password_plugin with no SSL
try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
HAVE_CRYPTOGRAPHY = True
except ImportError:
HAVE_CRYPTOGRAPHY = False


def _xor_password(password, salt):
password_bytes = bytearray(password, 'ascii')
salt_len = len(salt)
for i in range(len(password_bytes)):
password_bytes[i] ^= ord(salt[i % salt_len])
return password_bytes


def _sha256_rsa_crypt(password, salt, public_key):
if not HAVE_CRYPTOGRAPHY:
raise OperationalError("cryptography module not found for sha256_password_plugin")
message = _xor_password(password + b'\0', salt)
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
return rsa_key.encrypt(
message.decode('latin1').encode('latin1'), padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(),
label=None))


class SHA256PasswordPlugin(object):
def __init__(self, con):
self.con = con

def authenticate(self, pkt):
if self.con.ssl and self.con.server_capabilities & CLIENT.SSL:
data = self.con.password.encode('latin1') + b'\0'
else:
if pkt.is_auth_switch_request():
self.con.salt = pkt.read_all()
if self.con.server_public_key == '':
self.con.write_packet(b'\1')
pkt = self.con._read_packet()
if pkt.is_extra_auth_data() and self.con.server_public_key == '':
pkt.read_uint8()
self.con.server_public_key = pkt.read_all()
data = _sha256_rsa_crypt(
self.con.password,
self.con.salt,
self.con.server_public_key)
self.con.write_packet(data)
pkt = self.con._read_packet()
pkt.check_error()
return pkt
57 changes: 41 additions & 16 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import traceback
import warnings

from .auth import sha256_password_plugin as _auth
from .charset import MBLENGTH, charset_by_name, charset_by_id
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
from . import converters
Expand All @@ -39,7 +40,6 @@
# KeyError occurs when there's no entry in OS database for a current user.
DEFAULT_USER = None


DEBUG = False

_py_version = sys.version_info[:2]
Expand Down Expand Up @@ -135,7 +135,6 @@ def _scramble(password, message):
result = s.digest()
return _my_crypt(result, stage1)


def _my_crypt(message1, message2):
length = len(message1)
result = b''
Expand Down Expand Up @@ -365,6 +364,10 @@ def is_auth_switch_request(self):
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
return self._data[0:1] == b'\xfe'

def is_extra_auth_data(self):
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
return self._data[0:1] == b'\x01'

def is_resultset_packet(self):
field_count = ord(self._data[0:1])
return 1 <= field_count <= 250
Expand Down Expand Up @@ -559,6 +562,7 @@ class Connection(object):
The class needs an authenticate method taking an authentication packet as
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
(if no authenticate method) for returning a string from the user. (experimental)
:param server_public_key: SHA256 authenticaiton plugin public key value. (default: '')
:param db: Alias for database. (for compatibility to MySQLdb)
:param passwd: Alias for password. (for compatibility to MySQLdb)
:param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
Expand All @@ -581,7 +585,7 @@ def __init__(self, host=None, user=None, password="",
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False,
auth_plugin_map={}, read_timeout=None, write_timeout=None,
bind_address=None, binary_prefix=False):
bind_address=None, binary_prefix=False, server_public_key=''):
if no_delay is not None:
warnings.warn("no_delay option is deprecated", DeprecationWarning)

Expand Down Expand Up @@ -697,6 +701,9 @@ def _config(key, arg):
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
self._auth_plugin_map = auth_plugin_map
if b"sha256_password" not in self._auth_plugin_map:
self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin
self.server_public_key = server_public_key
self._binary_prefix = binary_prefix
if defer_connect:
self._sock = None
Expand Down Expand Up @@ -826,7 +833,7 @@ def select_db(self, db):

def escape(self, obj, mapping=None):
"""Escape whatever value you pass to it.

Non-standard, for internal use; do not use this in your applications.
"""
if isinstance(obj, str_type):
Expand All @@ -840,7 +847,7 @@ def escape(self, obj, mapping=None):

def literal(self, obj):
"""Alias for escape()

Non-standard, for internal use; do not use this in your applications.
"""
return self.escape(obj, self.encoders)
Expand Down Expand Up @@ -1180,6 +1187,14 @@ def _request_authentication(self):
authresp = b''
if self._auth_plugin_name in ('', 'mysql_native_password'):
authresp = _scramble(self.password.encode('latin1'), self.salt)
elif self._auth_plugin_name == 'sha256_password':
if self.ssl and self.server_capabilities & CLIENT.SSL:
authresp = self.password.encode('latin1') + b'\0'
else:
if self.password is not None:
authresp = b'\1'
else:
authresp = b'\0'

if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp)) + authresp
Expand Down Expand Up @@ -1215,24 +1230,20 @@ def _request_authentication(self):
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
self.write_packet(data)
auth_packet = self._read_packet()
elif auth_packet.is_extra_auth_data():
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
handler = self._get_auth_plugin_handler(self._auth_plugin_name)
handler.authenticate(auth_packet)

def _process_auth(self, plugin_name, auth_packet):
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class:
handler = self._get_auth_plugin_handler(plugin_name)
if handler != None:
try:
handler = plugin_class(self)
return handler.authenticate(auth_packet)
except AttributeError:
if plugin_name != b'dialog':
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
except TypeError:
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
else:
handler = None
if plugin_name == b"mysql_native_password":
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
data = _scramble(self.password.encode('latin1'), auth_packet.read_all())
Expand Down Expand Up @@ -1277,6 +1288,20 @@ def _process_auth(self, plugin_name, auth_packet):
pkt = self._read_packet()
pkt.check_error()
return pkt

def _get_auth_plugin_handler(self, plugin_name):
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class:
try:
handler = plugin_class(self)
except TypeError:
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
else:
handler = None
return handler

# _mysql support
def thread_id(self):
Expand Down Expand Up @@ -1551,7 +1576,7 @@ def _get_descriptions(self):
# This behavior is different from TEXT / BLOB.
# We should decode result by connection encoding regardless charsetnr.
# See https://github.com/PyMySQL/PyMySQL/issues/488
encoding = conn_encoding # SELECT CAST(... AS JSON)
encoding = conn_encoding # SELECT CAST(... AS JSON)
elif field_type in TEXT_TYPES:
if field.charsetnr == 63: # binary
# TEXTs with charset=binary means BINARY types.
Expand Down
5 changes: 3 additions & 2 deletions pymysql/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,12 @@ def testAuthSHA256(self):
else:
c.execute('SET old_passwords = 2')
c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
c.execute("FLUSH PRIVILEGES")
db = self.db.copy()
db['password'] = "Sh@256Pa33"
# not implemented yet so thows error
# Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test.
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_256', **db)
pymysql.connect(user='pymysql_sha256', **db)

class TestConnection(base.PyMySQLTestCase):

Expand Down