Skip to content

Commit b799bd9

Browse files
elemountmethane
authored andcommitted
Add support for SHA256 auth plugin
1 parent 14e4c25 commit b799bd9

File tree

5 files changed

+103
-18
lines changed

5 files changed

+103
-18
lines changed

pymysql/auth/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from ..constants import CLIENT
2+
from ..err import OperationalError
3+
4+
# Import cryptography for RSA_PKCS1_OAEP_PADDING algorithm
5+
# which is needed when use sha256_password_plugin with no SSL
6+
try:
7+
from cryptography.hazmat.backends import default_backend
8+
from cryptography.hazmat.primitives import serialization, hashes
9+
from cryptography.hazmat.primitives.asymmetric import padding
10+
HAVE_CRYPTOGRAPHY = True
11+
except ImportError:
12+
HAVE_CRYPTOGRAPHY = False
13+
14+
15+
def _xor_password(password, salt):
16+
password_bytes = bytearray(password, 'ascii')
17+
salt_len = len(salt)
18+
for i in range(len(password_bytes)):
19+
password_bytes[i] ^= ord(salt[i % salt_len])
20+
return password_bytes
21+
22+
23+
def _sha256_rsa_crypt(password, salt, public_key):
24+
if not HAVE_CRYPTOGRAPHY:
25+
raise OperationalError("cryptography module not found for sha256_password_plugin")
26+
message = _xor_password(password + b'\0', salt)
27+
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
28+
return rsa_key.encrypt(
29+
message.decode('latin1').encode('latin1'), padding.OAEP(
30+
mgf=padding.MGF1(algorithm=hashes.SHA1()),
31+
algorithm=hashes.SHA1(),
32+
label=None))
33+
34+
35+
class SHA256PasswordPlugin(object):
36+
def __init__(self, con):
37+
self.con = con
38+
39+
def authenticate(self, pkt):
40+
if self.con.ssl and self.con.server_capabilities & CLIENT.SSL:
41+
data = self.con.password.encode('latin1') + b'\0'
42+
else:
43+
if pkt.is_auth_switch_request():
44+
self.con.salt = pkt.read_all()
45+
if self.con.server_public_key == '':
46+
self.con.write_packet(b'\1')
47+
pkt = self.con._read_packet()
48+
if pkt.is_extra_auth_data() and self.con.server_public_key == '':
49+
pkt.read_uint8()
50+
self.con.server_public_key = pkt.read_all()
51+
data = _sha256_rsa_crypt(
52+
self.con.password,
53+
self.con.salt,
54+
self.con.server_public_key)
55+
self.con.write_packet(data)
56+
pkt = self.con._read_packet()
57+
pkt.check_error()
58+
return pkt

pymysql/connections.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import traceback
1717
import warnings
1818

19+
from .auth import sha256_password_plugin as _auth
1920
from .charset import charset_by_name, charset_by_id
2021
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
2122
from . import converters
@@ -43,7 +44,6 @@
4344
# KeyError occurs when there's no entry in OS database for a current user.
4445
DEFAULT_USER = None
4546

46-
4747
DEBUG = False
4848

4949
_py_version = sys.version_info[:2]
@@ -107,7 +107,6 @@ def _scramble(password, message):
107107
result = s.digest()
108108
return _my_crypt(result, stage1)
109109

110-
111110
def _my_crypt(message1, message2):
112111
length = len(message1)
113112
result = b''
@@ -186,6 +185,7 @@ def lenenc_int(i):
186185
else:
187186
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
188187

188+
189189
class Connection(object):
190190
"""
191191
Representation of a socket with a mysql server.
@@ -240,6 +240,7 @@ class Connection(object):
240240
The class needs an authenticate method taking an authentication packet as
241241
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
242242
(if no authenticate method) for returning a string from the user. (experimental)
243+
:param server_public_key: SHA256 authenticaiton plugin public key value. (default: '')
243244
:param db: Alias for database. (for compatibility to MySQLdb)
244245
:param passwd: Alias for password. (for compatibility to MySQLdb)
245246
:param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
@@ -262,7 +263,7 @@ def __init__(self, host=None, user=None, password="",
262263
autocommit=False, db=None, passwd=None, local_infile=False,
263264
max_allowed_packet=16*1024*1024, defer_connect=False,
264265
auth_plugin_map={}, read_timeout=None, write_timeout=None,
265-
bind_address=None, binary_prefix=False):
266+
bind_address=None, binary_prefix=False, server_public_key=''):
266267
if no_delay is not None:
267268
warnings.warn("no_delay option is deprecated", DeprecationWarning)
268269

@@ -379,6 +380,9 @@ def _config(key, arg):
379380
self.max_allowed_packet = max_allowed_packet
380381
self._auth_plugin_map = auth_plugin_map
381382
self._binary_prefix = binary_prefix
383+
if b"sha256_password" not in self._auth_plugin_map:
384+
self._auth_plugin_map[b"sha256_password"] = _auth.SHA256PasswordPlugin
385+
self.server_public_key = server_public_key
382386
if defer_connect:
383387
self._sock = None
384388
else:
@@ -507,7 +511,7 @@ def select_db(self, db):
507511

508512
def escape(self, obj, mapping=None):
509513
"""Escape whatever value you pass to it.
510-
514+
511515
Non-standard, for internal use; do not use this in your applications.
512516
"""
513517
if isinstance(obj, str_type):
@@ -521,7 +525,7 @@ def escape(self, obj, mapping=None):
521525

522526
def literal(self, obj):
523527
"""Alias for escape()
524-
528+
525529
Non-standard, for internal use; do not use this in your applications.
526530
"""
527531
return self.escape(obj, self.encoders)
@@ -861,6 +865,14 @@ def _request_authentication(self):
861865
authresp = b''
862866
if self._auth_plugin_name in ('', 'mysql_native_password'):
863867
authresp = _scramble(self.password.encode('latin1'), self.salt)
868+
elif self._auth_plugin_name == 'sha256_password':
869+
if self.ssl and self.server_capabilities & CLIENT.SSL:
870+
authresp = self.password.encode('latin1') + b'\0'
871+
else:
872+
if self.password is not None:
873+
authresp = b'\1'
874+
else:
875+
authresp = b'\0'
864876

865877
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
866878
data += lenenc_int(len(authresp)) + authresp
@@ -896,24 +908,20 @@ def _request_authentication(self):
896908
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
897909
self.write_packet(data)
898910
auth_packet = self._read_packet()
911+
elif auth_packet.is_extra_auth_data():
912+
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
913+
handler = self._get_auth_plugin_handler(self._auth_plugin_name)
914+
handler.authenticate(auth_packet)
899915

900916
def _process_auth(self, plugin_name, auth_packet):
901-
plugin_class = self._auth_plugin_map.get(plugin_name)
902-
if not plugin_class:
903-
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
904-
if plugin_class:
917+
handler = self._get_auth_plugin_handler(plugin_name)
918+
if handler != None:
905919
try:
906-
handler = plugin_class(self)
907920
return handler.authenticate(auth_packet)
908921
except AttributeError:
909922
if plugin_name != b'dialog':
910923
raise err.OperationalError(2059, "Authentication plugin '%s'" \
911924
" not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
912-
except TypeError:
913-
raise err.OperationalError(2059, "Authentication plugin '%s'" \
914-
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
915-
else:
916-
handler = None
917925
if plugin_name == b"mysql_native_password":
918926
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
919927
data = _scramble(self.password.encode('latin1'), auth_packet.read_all())
@@ -958,6 +966,20 @@ def _process_auth(self, plugin_name, auth_packet):
958966
pkt = self._read_packet()
959967
pkt.check_error()
960968
return pkt
969+
970+
def _get_auth_plugin_handler(self, plugin_name):
971+
plugin_class = self._auth_plugin_map.get(plugin_name)
972+
if not plugin_class:
973+
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
974+
if plugin_class:
975+
try:
976+
handler = plugin_class(self)
977+
except TypeError:
978+
raise err.OperationalError(2059, "Authentication plugin '%s'" \
979+
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
980+
else:
981+
handler = None
982+
return handler
961983

962984
# _mysql support
963985
def thread_id(self):
@@ -1232,7 +1254,7 @@ def _get_descriptions(self):
12321254
# This behavior is different from TEXT / BLOB.
12331255
# We should decode result by connection encoding regardless charsetnr.
12341256
# See https://github.com/PyMySQL/PyMySQL/issues/488
1235-
encoding = conn_encoding # SELECT CAST(... AS JSON)
1257+
encoding = conn_encoding # SELECT CAST(... AS JSON)
12361258
elif field_type in TEXT_TYPES:
12371259
if field.charsetnr == 63: # binary
12381260
# TEXTs with charset=binary means BINARY types.

pymysql/protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ def is_auth_switch_request(self):
196196
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
197197
return self._data[0:1] == b'\xfe'
198198

199+
def is_extra_auth_data(self):
200+
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
201+
return self._data[0:1] == b'\x01'
202+
199203
def is_resultset_packet(self):
200204
field_count = ord(self._data[0:1])
201205
return 1 <= field_count <= 250

pymysql/tests/test_connection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,12 @@ def testAuthSHA256(self):
360360
else:
361361
c.execute('SET old_passwords = 2')
362362
c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
363+
c.execute("FLUSH PRIVILEGES")
363364
db = self.db.copy()
364365
db['password'] = "Sh@256Pa33"
365-
# not implemented yet so thows error
366+
# Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test.
366367
with self.assertRaises(pymysql.err.OperationalError):
367-
pymysql.connect(user='pymysql_256', **db)
368+
pymysql.connect(user='pymysql_sha256', **db)
368369

369370
class TestConnection(base.PyMySQLTestCase):
370371

0 commit comments

Comments
 (0)