diff --git a/.gitignore b/.gitignore index cd93a4b0..0b2c85be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,14 @@ *.pyc *.pyo -__pycache__ -.coverage -/dist -/PyMySQL.egg-info +/.cache +/.coverage +/.idea /.tox +/.venv +/.vscode +/PyMySQL.egg-info /build +/dist +/docs/build /pymysql/tests/databases.json - -/.idea -docs/build +__pycache__ diff --git a/.travis.yml b/.travis.yml index 2822cd05..8d960249 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,13 +35,14 @@ matrix: python: "3.4" - env: - DB=mysql:8.0 + - TEST_AUTH=yes python: "3.7-dev" # different py version from 5.6 and 5.7 as cache seems to be based on py version # http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version # really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't install: - - pip install -U coveralls unittest2 coverage + - pip install -U coveralls unittest2 coverage cryptography pytest before_script: - ./.travis/initializedb.sh @@ -51,6 +52,9 @@ before_script: script: - coverage run ./runtests.py + - if [ "${TEST_AUTH}" = "yes" ]; + then pytest -v tests; + fi - if [ ! -z "${DB}" ]; then docker logs mysqld; fi diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh index 18c00eca..d9897e49 100755 --- a/.travis/initializedb.sh +++ b/.travis/initializedb.sh @@ -37,6 +37,16 @@ if [ ! -z "${DB}" ]; then docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}" docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}" docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}" + + # Test user for auth test + mysql -e ' + CREATE USER + user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256", + nopass_sha256 IDENTIFIED WITH "sha256_password", + user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2", + nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password" + PASSWORD EXPIRE NEVER;' + mysql -e 'GRANT RELOAD ON *.* TO user_caching_sha2;' else WITH_PLUGIN='' fi diff --git a/pymysql/_auth.py b/pymysql/_auth.py new file mode 100644 index 00000000..ddf6e4e5 --- /dev/null +++ b/pymysql/_auth.py @@ -0,0 +1,252 @@ +""" +Implements auth methods +""" +from ._compat import text_type +from .constants import CLIENT +from .err import OperationalError + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import padding + +from functools import partial +import hashlib +import struct + + +DEBUG = True +SCRAMBLE_LENGTH = 20 +sha1_new = partial(hashlib.new, 'sha1') + + +# mysql_native_password +# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41 + + +def scramble_native_password(password, message): + """Scramble used for mysql_native_password""" + if not password: + return b'' + + stage1 = sha1_new(password).digest() + stage2 = sha1_new(stage1).digest() + s = sha1_new() + s.update(message[:SCRAMBLE_LENGTH]) + s.update(stage2) + result = s.digest() + return _my_crypt(result, stage1) + + +def _my_crypt(message1, message2): + length = len(message1) + result = b'' + for i in range(length): + x = ( + struct.unpack('B', message1[i:i + 1])[0] ^ + struct.unpack('B', message2[i:i + 1])[0] + ) + result += struct.pack('B', x) + return result + + +# old_passwords support ported from libmysql/password.c +# https://dev.mysql.com/doc/internals/en/old-password-authentication.html + +SCRAMBLE_LENGTH_323 = 8 + + +class RandStruct_323(object): + + def __init__(self, seed1, seed2): + self.max_value = 0x3FFFFFFF + self.seed1 = seed1 % self.max_value + self.seed2 = seed2 % self.max_value + + def my_rnd(self): + self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value + self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value + return float(self.seed1) / float(self.max_value) + + +def scramble_old_password(password, message): + """Scramble for old_password""" + hash_pass = _hash_password_323(password) + hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) + hash_pass_n = struct.unpack(">LL", hash_pass) + hash_message_n = struct.unpack(">LL", hash_message) + + rand_st = RandStruct_323( + hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1] + ) + outbuf = io.BytesIO() + for _ in range(min(SCRAMBLE_LENGTH_323, len(message))): + outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) + extra = int2byte(int(rand_st.my_rnd() * 31)) + out = outbuf.getvalue() + outbuf = io.BytesIO() + for c in out: + outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) + return outbuf.getvalue() + + +def _hash_password_323(password): + nr = 1345345333 + add = 7 + nr2 = 0x12345671 + + # x in py3 is numbers, p27 is chars + for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: + nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF + nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF + add = (add + c) & 0xFFFFFFFF + + r1 = nr & ((1 << 31) - 1) # kill sign bits + r2 = nr2 & ((1 << 31) - 1) + return struct.pack(">LL", r1, r2) + + +# sha256_password + + +def _roundtrip(conn, send_data): + conn.write_packet(send_data) + pkt = conn._read_packet() + pkt.check_error() + return pkt + + +def _xor_password(password, salt): + password_bytes = bytearray(password) + salt = bytearray(salt) # for PY2 compat. + salt_len = len(salt) + for i in range(len(password_bytes)): + password_bytes[i] ^= salt[i % salt_len] + return bytes(password_bytes) + + +def sha2_rsa_encrypt(password, salt, public_key): + """Encrypt password with salt and public_key. + + Used for sha256_password and caching_sha2_password. + """ + message = _xor_password(password + b'\0', salt) + rsa_key = serialization.load_pem_public_key(public_key, default_backend()) + return rsa_key.encrypt( + message, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA1()), + algorithm=hashes.SHA1(), + label=None, + ), + ) + + +def sha256_password_auth(conn, pkt): + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: + print("sha256: Sending plain password") + data = conn.password + b'\0' + return _roundtrip(conn, data) + + if pkt.is_auth_switch_request(): + conn.salt = pkt.read_all() + if not conn.server_public_key and conn.password: + # Request server public key + if DEBUG: + print("sha256: Requesting server public key") + pkt = _roundtrip(conn, b'\1') + + if pkt.is_extra_auth_data(): + conn.server_public_key = pkt._data[1:] + if DEBUG: + print("Received public key:\n", conn.server_public_key.decode('ascii')) + + if conn.password: + if not conn.server_public_key: + raise OperationalError("Couldn't receive server's public key") + + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) + else: + data = b'' + + return _roundtrip(conn, data) + + +def scramble_caching_sha2(password, nonce): + # (bytes, bytes) -> bytes + """Scramble algorithm used in cached_sha2_password fast path. + + XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce)) + """ + if not password: + return b'' + + p1 = hashlib.sha256(password).digest() + p2 = hashlib.sha256(p1).digest() + p3 = hashlib.sha256(p2 + nonce).digest() + + res = bytearray(p1) + for i in range(len(p3)): + res[i] ^= p3[i] + + return bytes(res) + + +def caching_sha2_password_auth(conn, pkt): + # No password fast path + if not conn.password: + return _roundtrip(conn, b'') + + if pkt.is_auth_switch_request(): + # Try from fast auth + if DEBUG: + print("caching sha2: Trying fast path") + conn.salt = pkt.read_all() + scrambled = scramble_caching_sha2(conn.password, conn.salt) + pkt = _roundtrip(conn, scrambled) + # else: fast auth is tried in initial handshake + + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1] + ) + + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + pkt.advance(1) + n = pkt.read_uint8() + + if n == 3: + if DEBUG: + print("caching sha2: succeeded by fast path.") + pkt = conn._read_packet() + pkt.check_error() # pkt must be OK packet + return pkt + + if n != 4: + raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) + + if DEBUG: + print("caching sha2: Trying full auth...") + + if conn.ssl and conn.server_capabilities & CLIENT.SSL: + if DEBUG: + print("caching sha2: Sending plain password via SSL") + return _roundtrip(conn, conn.password + b'\0') + + if not conn.server_public_key: + pkt = _roundtrip(conn, b'\x02') # Request public key + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] + ) + + conn.server_public_key = pkt._data[1:] + if DEBUG: + print(conn.server_public_key.decode('ascii')) + + data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) + pkt = _roundtrip(conn, data) diff --git a/pymysql/connections.py b/pymysql/connections.py index d5dc0aff..14ae76d9 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -6,8 +6,6 @@ from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON import errno -from functools import partial -import hashlib import io import os import socket @@ -16,6 +14,8 @@ import traceback import warnings +from . import _auth + from .charset import charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS from . import converters @@ -43,7 +43,6 @@ # KeyError occurs when there's no entry in OS database for a current user. DEFAULT_USER = None - DEBUG = False _py_version = sys.version_info[:2] @@ -87,90 +86,16 @@ def _makefile(sock, mode): FIELD_TYPE.VARCHAR, FIELD_TYPE.GEOMETRY]) -sha_new = partial(hashlib.new, 'sha1') -DEFAULT_CHARSET = 'latin1' +DEFAULT_CHARSET = 'latin1' # TODO: change to utf8mb4 MAX_PACKET_LEN = 2**24-1 -SCRAMBLE_LENGTH = 20 - -def _scramble(password, message): - if not password: - return b'' - if DEBUG: print('password=' + str(password)) - stage1 = sha_new(password).digest() - stage2 = sha_new(stage1).digest() - s = sha_new() - s.update(message[:SCRAMBLE_LENGTH]) - s.update(stage2) - result = s.digest() - return _my_crypt(result, stage1) - - -def _my_crypt(message1, message2): - length = len(message1) - result = b'' - for i in range_type(length): - x = (struct.unpack('B', message1[i:i+1])[0] ^ - struct.unpack('B', message2[i:i+1])[0]) - result += struct.pack('B', x) - return result - -# old_passwords support ported from libmysql/password.c -SCRAMBLE_LENGTH_323 = 8 - - -class RandStruct_323(object): - def __init__(self, seed1, seed2): - self.max_value = 0x3FFFFFFF - self.seed1 = seed1 % self.max_value - self.seed2 = seed2 % self.max_value - - def my_rnd(self): - self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value - self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value - return float(self.seed1) / float(self.max_value) - - -def _scramble_323(password, message): - hash_pass = _hash_password_323(password) - hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) - hash_pass_n = struct.unpack(">LL", hash_pass) - hash_message_n = struct.unpack(">LL", hash_message) - - rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0], - hash_pass_n[1] ^ hash_message_n[1]) - outbuf = io.BytesIO() - for _ in range_type(min(SCRAMBLE_LENGTH_323, len(message))): - outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) - extra = int2byte(int(rand_st.my_rnd() * 31)) - out = outbuf.getvalue() - outbuf = io.BytesIO() - for c in out: - outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) - return outbuf.getvalue() - - -def _hash_password_323(password): - nr = 1345345333 - add = 7 - nr2 = 0x12345671 - - # x in py3 is numbers, p27 is chars - for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: - nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF - nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF - add = (add + c) & 0xFFFFFFFF - - r1 = nr & ((1 << 31) - 1) # kill sign bits - r2 = nr2 & ((1 << 31) - 1) - return struct.pack(">LL", r1, r2) - def pack_int24(n): return struct.pack('`_ in the specification. - + :raise Error: If the connection is already closed. """ if self._closed: @@ -446,7 +376,7 @@ def _force_close(self): if self._sock: try: self._sock.close() - except: + except: # noqa pass self._sock = None self._rfile = None @@ -485,7 +415,7 @@ def begin(self): def commit(self): """ Commit changes to stable storage. - + See `Connection.commit() `_ in the specification. """ @@ -495,7 +425,7 @@ def commit(self): def rollback(self): """ Roll back the current transaction. - + See `Connection.rollback() `_ in the specification. """ @@ -512,7 +442,7 @@ def show_warnings(self): def select_db(self, db): """ Set current db. - + :param db: The name of the db. """ self._execute_command(COMMAND.COM_INIT_DB, db) @@ -520,7 +450,7 @@ def select_db(self, db): def escape(self, obj, mapping=None): """Escape whatever value you pass to it. - + Non-standard, for internal use; do not use this in your applications. """ if isinstance(obj, str_type): @@ -534,7 +464,7 @@ def escape(self, obj, mapping=None): def literal(self, obj): """Alias for escape() - + Non-standard, for internal use; do not use this in your applications. """ return self.escape(obj, self.encoders) @@ -554,7 +484,7 @@ def _quote_bytes(self, s): def cursor(self, cursor=None): """ Create a new cursor to execute queries with. - + :param cursor: The type of cursor to create; one of :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`, or :py:class:`SSDictCursor`. None means use Cursor. @@ -602,7 +532,7 @@ def kill(self, thread_id): def ping(self, reconnect=True): """ Check if the server is alive. - + :param reconnect: If the connection is closed, reconnect. :raise Error: If the connection is closed and reconnect=False. """ @@ -684,7 +614,7 @@ def connect(self, sock=None): if sock is not None: try: sock.close() - except: + except: # noqa pass if isinstance(e, (OSError, IOError, socket.error)): @@ -811,7 +741,6 @@ def _execute_command(self, command, sql): :raise InterfaceError: If the connection is closed. :raise ValueError: If no username was specified. """ - if not self._sock: raise err.InterfaceError("(0, '')") @@ -861,7 +790,7 @@ def _request_authentication(self): if isinstance(self.user, text_type): self.user = self.user.encode(self.encoding) - data_init = struct.pack('