From d598a1550ce63e7ea6d0afa3e84191435ce9174d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 3 Jan 2021 12:01:32 +0900 Subject: [PATCH] Reformat with black --- pymysql/__init__.py | 125 ++- pymysql/_auth.py | 45 +- pymysql/_socketio.py | 16 +- pymysql/charset.py | 317 ++++--- pymysql/connections.py | 423 ++++++---- pymysql/constants/CLIENT.py | 13 +- pymysql/constants/COMMAND.py | 25 +- pymysql/constants/CR.py | 100 +-- pymysql/constants/FIELD_TYPE.py | 2 - pymysql/constants/SERVER_STATUS.py | 1 - pymysql/converters.py | 86 +- pymysql/cursors.py | 75 +- pymysql/err.py | 78 +- pymysql/optionfile.py | 4 +- pymysql/protocol.py | 111 ++- pymysql/tests/__init__.py | 1 + pymysql/tests/base.py | 18 +- pymysql/tests/test_DictCursor.py | 52 +- pymysql/tests/test_SSCursor.py | 102 ++- pymysql/tests/test_basic.py | 186 ++-- pymysql/tests/test_connection.py | 445 ++++++---- pymysql/tests/test_converters.py | 24 +- pymysql/tests/test_cursor.py | 74 +- pymysql/tests/test_err.py | 3 +- pymysql/tests/test_issues.py | 140 +-- pymysql/tests/test_load_local.py | 31 +- pymysql/tests/test_nextset.py | 12 +- pymysql/tests/test_optionfile.py | 7 +- pymysql/tests/thirdparty/__init__.py | 1 + .../tests/thirdparty/test_MySQLdb/__init__.py | 1 + .../thirdparty/test_MySQLdb/capabilities.py | 243 +++--- .../tests/thirdparty/test_MySQLdb/dbapi20.py | 794 +++++++++--------- .../test_MySQLdb/test_MySQLdb_capabilities.py | 73 +- .../test_MySQLdb/test_MySQLdb_dbapi20.py | 200 +++-- .../test_MySQLdb/test_MySQLdb_nonstandard.py | 46 +- pymysql/util.py | 1 - tests/test_auth.py | 42 +- tests/test_mariadb_auth.py | 5 +- 38 files changed, 2296 insertions(+), 1626 deletions(-) diff --git a/pymysql/__init__.py b/pymysql/__init__.py index 1e126dcd..5b49262e 100644 --- a/pymysql/__init__.py +++ b/pymysql/__init__.py @@ -26,12 +26,26 @@ from .constants import FIELD_TYPE from .converters import escape_dict, escape_sequence, escape_string from .err import ( - Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, InternalError, - NotSupportedError, ProgrammingError, MySQLError) + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError, + MySQLError, +) from .times import ( - Date, Time, Timestamp, - DateFromTicks, TimeFromTicks, TimestampFromTicks) + Date, + Time, + Timestamp, + DateFromTicks, + TimeFromTicks, + TimestampFromTicks, +) VERSION = (0, 10, 1, None) @@ -45,7 +59,6 @@ class DBAPISet(frozenset): - def __ne__(self, other): if isinstance(other, set): return frozenset.__ne__(self, other) @@ -62,18 +75,32 @@ def __hash__(self): return frozenset.__hash__(self) -STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, - FIELD_TYPE.VAR_STRING]) -BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, - FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) -NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, - FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, - FIELD_TYPE.TINY, FIELD_TYPE.YEAR]) -DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) -TIME = DBAPISet([FIELD_TYPE.TIME]) +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet( + [ + FIELD_TYPE.BLOB, + FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, + FIELD_TYPE.TINY_BLOB, + ] +) +NUMBER = DBAPISet( + [ + FIELD_TYPE.DECIMAL, + FIELD_TYPE.DOUBLE, + FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, + FIELD_TYPE.LONG, + FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, + FIELD_TYPE.YEAR, + ] +) +DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) -DATETIME = TIMESTAMP -ROWID = DBAPISet() +DATETIME = TIMESTAMP +ROWID = DBAPISet() def Binary(x): @@ -87,9 +114,12 @@ def Connect(*args, **kwargs): more information. """ from .connections import Connection + return Connection(*args, **kwargs) + from . import connections as _orig_conn + if _orig_conn.Connection.__init__.__doc__ is not None: Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ del _orig_conn @@ -99,7 +129,8 @@ def get_client_info(): # for MySQLdb compatibility version = VERSION if VERSION[3] is None: version = VERSION[:3] - return '.'.join(map(str, version)) + return ".".join(map(str, version)) + connect = Connection = Connect @@ -110,9 +141,11 @@ def get_client_info(): # for MySQLdb compatibility __version__ = get_client_info() + def thread_safe(): return True # match MySQLdb.thread_safe() + def install_as_MySQLdb(): """ After this function is called, any application that imports MySQLdb or @@ -122,16 +155,50 @@ def install_as_MySQLdb(): __all__ = [ - 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date', - 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', - 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError', - 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER', - 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError', - 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect', - 'connections', 'constants', 'converters', 'cursors', - 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info', - 'paramstyle', 'threadsafety', 'version_info', - + "BINARY", + "Binary", + "Connect", + "Connection", + "DATE", + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "DataError", + "DatabaseError", + "Error", + "FIELD_TYPE", + "IntegrityError", + "InterfaceError", + "InternalError", + "MySQLError", + "NULL", + "NUMBER", + "NotSupportedError", + "DBAPISet", + "OperationalError", + "ProgrammingError", + "ROWID", + "STRING", + "TIME", + "TIMESTAMP", + "Warning", + "apilevel", + "connect", + "connections", + "constants", + "converters", + "cursors", + "escape_dict", + "escape_sequence", + "escape_string", + "get_client_info", + "paramstyle", + "threadsafety", + "version_info", "install_as_MySQLdb", - "NULL", "__version__", + "NULL", + "__version__", ] diff --git a/pymysql/_auth.py b/pymysql/_auth.py index 77caeafd..d16a0895 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -9,6 +9,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives.asymmetric import padding + _have_cryptography = True except ImportError: _have_cryptography = False @@ -22,7 +23,7 @@ DEBUG = False SCRAMBLE_LENGTH = 20 -sha1_new = partial(hashlib.new, 'sha1') +sha1_new = partial(hashlib.new, "sha1") # mysql_native_password @@ -32,7 +33,7 @@ def scramble_native_password(password, message): """Scramble used for mysql_native_password""" if not password: - return b'' + return b"" stage1 = sha1_new(password).digest() stage2 = sha1_new(stage1).digest() @@ -59,7 +60,6 @@ def _my_crypt(message1, message2): class RandStruct_323: - def __init__(self, seed1, seed2): self.max_value = 0x3FFFFFFF self.seed1 = seed1 % self.max_value @@ -73,8 +73,10 @@ def my_rnd(self): def scramble_old_password(password, message): """Scramble for old_password""" - warnings.warn("old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n" - "old password support will be removed in future PyMySQL version") + warnings.warn( + "old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n" + "old password support will be removed in future PyMySQL version" + ) hash_pass = _hash_password_323(password) hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) hash_pass_n = struct.unpack(">LL", hash_pass) @@ -100,7 +102,7 @@ def _hash_password_323(password): nr2 = 0x12345671 # x in py3 is numbers, p27 is chars - for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: + for c in [byte2int(x) for x in password if x not in (" ", "\t", 32, 9)]: nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF add = (add + c) & 0xFFFFFFFF @@ -120,9 +122,12 @@ def _init_nacl(): global _nacl_bindings try: from nacl import bindings + _nacl_bindings = bindings except ImportError: - raise RuntimeError("'pynacl' package is required for ed25519_password auth method") + raise RuntimeError( + "'pynacl' package is required for ed25519_password auth method" + ) def _scalar_clamp(s32): @@ -185,7 +190,7 @@ def _xor_password(password, salt): # See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/auth/sha2_password.cc#L939-L945 salt = salt[:SCRAMBLE_LENGTH] password_bytes = bytearray(password) - #salt = bytearray(salt) # for PY2 compat. + # salt = bytearray(salt) # for PY2 compat. salt_len = len(salt) for i in range(len(password_bytes)): password_bytes[i] ^= salt[i % salt_len] @@ -198,8 +203,10 @@ def sha2_rsa_encrypt(password, salt, public_key): Used for sha256_password and caching_sha2_password. """ if not _have_cryptography: - raise RuntimeError("'cryptography' package is required for sha256_password or caching_sha2_password auth methods") - message = _xor_password(password + b'\0', salt) + raise RuntimeError( + "'cryptography' package is required for sha256_password or caching_sha2_password auth methods" + ) + message = _xor_password(password + b"\0", salt) rsa_key = serialization.load_pem_public_key(public_key, default_backend()) return rsa_key.encrypt( message, @@ -215,7 +222,7 @@ def sha256_password_auth(conn, pkt): if conn._secure: if DEBUG: print("sha256: Sending plain password") - data = conn.password + b'\0' + data = conn.password + b"\0" return _roundtrip(conn, data) if pkt.is_auth_switch_request(): @@ -224,12 +231,12 @@ def sha256_password_auth(conn, pkt): # Request server public key if DEBUG: print("sha256: Requesting server public key") - pkt = _roundtrip(conn, b'\1') + pkt = _roundtrip(conn, b"\1") if pkt.is_extra_auth_data(): conn.server_public_key = pkt._data[1:] if DEBUG: - print("Received public key:\n", conn.server_public_key.decode('ascii')) + print("Received public key:\n", conn.server_public_key.decode("ascii")) if conn.password: if not conn.server_public_key: @@ -237,7 +244,7 @@ def sha256_password_auth(conn, pkt): data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) else: - data = b'' + data = b"" return _roundtrip(conn, data) @@ -249,7 +256,7 @@ def scramble_caching_sha2(password, nonce): XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce)) """ if not password: - return b'' + return b"" p1 = hashlib.sha256(password).digest() p2 = hashlib.sha256(p1).digest() @@ -265,7 +272,7 @@ def scramble_caching_sha2(password, nonce): def caching_sha2_password_auth(conn, pkt): # No password fast path if not conn.password: - return _roundtrip(conn, b'') + return _roundtrip(conn, b"") if pkt.is_auth_switch_request(): # Try from fast auth @@ -305,10 +312,10 @@ def caching_sha2_password_auth(conn, pkt): if conn._secure: if DEBUG: print("caching sha2: Sending plain password via secure connection") - return _roundtrip(conn, conn.password + b'\0') + return _roundtrip(conn, conn.password + b"\0") if not conn.server_public_key: - pkt = _roundtrip(conn, b'\x02') # Request public key + pkt = _roundtrip(conn, b"\x02") # Request public key if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] @@ -316,7 +323,7 @@ def caching_sha2_password_auth(conn, pkt): conn.server_public_key = pkt._data[1:] if DEBUG: - print(conn.server_public_key.decode('ascii')) + print(conn.server_public_key.decode("ascii")) data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) pkt = _roundtrip(conn, data) diff --git a/pymysql/_socketio.py b/pymysql/_socketio.py index 6a11d42e..6b2d65a3 100644 --- a/pymysql/_socketio.py +++ b/pymysql/_socketio.py @@ -8,11 +8,12 @@ import io import errno -__all__ = ['SocketIO'] +__all__ = ["SocketIO"] EINTR = errno.EINTR _blocking_errnos = (errno.EAGAIN, errno.EWOULDBLOCK) + class SocketIO(io.RawIOBase): """Raw I/O implementation for stream sockets. @@ -85,29 +86,25 @@ def write(self, b): raise def readable(self): - """True if the SocketIO is open for reading. - """ + """True if the SocketIO is open for reading.""" if self.closed: raise ValueError("I/O operation on closed socket.") return self._reading def writable(self): - """True if the SocketIO is open for writing. - """ + """True if the SocketIO is open for writing.""" if self.closed: raise ValueError("I/O operation on closed socket.") return self._writing def seekable(self): - """True if the SocketIO is open for seeking. - """ + """True if the SocketIO is open for seeking.""" if self.closed: raise ValueError("I/O operation on closed socket.") return super().seekable() def fileno(self): - """Return the file descriptor of the underlying socket. - """ + """Return the file descriptor of the underlying socket.""" self._checkClosed() return self._sock.fileno() @@ -131,4 +128,3 @@ def close(self): io.RawIOBase.close(self) self._sock._decref_socketios() self._sock = None - diff --git a/pymysql/charset.py b/pymysql/charset.py index 3ef3ea46..ac87c53d 100644 --- a/pymysql/charset.py +++ b/pymysql/charset.py @@ -1,31 +1,29 @@ -MBLENGTH = { - 8:1, - 33:3, - 88:2, - 91:2 - } +MBLENGTH = {8: 1, 33: 3, 88: 2, 91: 2} class Charset: def __init__(self, id, name, collation, is_default): self.id, self.name, self.collation = id, name, collation - self.is_default = is_default == 'Yes' + self.is_default = is_default == "Yes" def __repr__(self): return "Charset(id=%s, name=%r, collation=%r)" % ( - self.id, self.name, self.collation) + self.id, + self.name, + self.collation, + ) @property def encoding(self): name = self.name - if name in ('utf8mb4', 'utf8mb3'): - return 'utf8' - if name == 'latin1': - return 'cp1252' - if name == 'koi8r': - return 'koi8_r' - if name == 'koi8u': - return 'koi8_u' + if name in ("utf8mb4", "utf8mb3"): + return "utf8" + if name == "latin1": + return "cp1252" + if name == "koi8r": + return "koi8_r" + if name == "koi8u": + return "koi8_u" return name @property @@ -49,6 +47,7 @@ def by_id(self, id): def by_name(self, name): return self._by_name.get(name.lower()) + _charsets = Charsets() """ Generated with: @@ -62,149 +61,149 @@ def by_name(self, name): " """ -_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes')) -_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', '')) -_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes')) -_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes')) -_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', '')) -_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes')) -_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes')) -_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes')) -_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes')) -_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes')) -_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes')) -_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes')) -_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes')) -_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', '')) -_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', '')) -_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes')) -_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes')) -_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes')) -_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', '')) -_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', '')) -_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes')) -_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', '')) -_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes')) -_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes')) -_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes')) -_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', '')) -_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes')) -_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', '')) -_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes')) -_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', '')) -_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes')) -_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes')) -_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', '')) -_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes')) -_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes')) -_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes')) -_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes')) -_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes')) -_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes')) -_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', '')) -_charsets.add(Charset(43, 'macce', 'macce_bin', '')) -_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', '')) -_charsets.add(Charset(45, 'utf8mb4', 'utf8mb4_general_ci', 'Yes')) -_charsets.add(Charset(46, 'utf8mb4', 'utf8mb4_bin', '')) -_charsets.add(Charset(47, 'latin1', 'latin1_bin', '')) -_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', '')) -_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', '')) -_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', '')) -_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes')) -_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', '')) -_charsets.add(Charset(53, 'macroman', 'macroman_bin', '')) -_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes')) -_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', '')) -_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes')) -_charsets.add(Charset(63, 'binary', 'binary', 'Yes')) -_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', '')) -_charsets.add(Charset(65, 'ascii', 'ascii_bin', '')) -_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', '')) -_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', '')) -_charsets.add(Charset(68, 'cp866', 'cp866_bin', '')) -_charsets.add(Charset(69, 'dec8', 'dec8_bin', '')) -_charsets.add(Charset(70, 'greek', 'greek_bin', '')) -_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', '')) -_charsets.add(Charset(72, 'hp8', 'hp8_bin', '')) -_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', '')) -_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', '')) -_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', '')) -_charsets.add(Charset(76, 'utf8', 'utf8_tolower_ci', '')) -_charsets.add(Charset(77, 'latin2', 'latin2_bin', '')) -_charsets.add(Charset(78, 'latin5', 'latin5_bin', '')) -_charsets.add(Charset(79, 'latin7', 'latin7_bin', '')) -_charsets.add(Charset(80, 'cp850', 'cp850_bin', '')) -_charsets.add(Charset(81, 'cp852', 'cp852_bin', '')) -_charsets.add(Charset(82, 'swe7', 'swe7_bin', '')) -_charsets.add(Charset(83, 'utf8', 'utf8_bin', '')) -_charsets.add(Charset(84, 'big5', 'big5_bin', '')) -_charsets.add(Charset(85, 'euckr', 'euckr_bin', '')) -_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', '')) -_charsets.add(Charset(87, 'gbk', 'gbk_bin', '')) -_charsets.add(Charset(88, 'sjis', 'sjis_bin', '')) -_charsets.add(Charset(89, 'tis620', 'tis620_bin', '')) -_charsets.add(Charset(91, 'ujis', 'ujis_bin', '')) -_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes')) -_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', '')) -_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', '')) -_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes')) -_charsets.add(Charset(96, 'cp932', 'cp932_bin', '')) -_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes')) -_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', '')) -_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', '')) -_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', '')) -_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', '')) -_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', '')) -_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', '')) -_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', '')) -_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', '')) -_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', '')) -_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', '')) -_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', '')) -_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', '')) -_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', '')) -_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', '')) -_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', '')) -_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', '')) -_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', '')) -_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', '')) -_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', '')) -_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', '')) -_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', '')) -_charsets.add(Charset(211, 'utf8', 'utf8_sinhala_ci', '')) -_charsets.add(Charset(212, 'utf8', 'utf8_german2_ci', '')) -_charsets.add(Charset(213, 'utf8', 'utf8_croatian_ci', '')) -_charsets.add(Charset(214, 'utf8', 'utf8_unicode_520_ci', '')) -_charsets.add(Charset(215, 'utf8', 'utf8_vietnamese_ci', '')) -_charsets.add(Charset(223, 'utf8', 'utf8_general_mysql500_ci', '')) -_charsets.add(Charset(224, 'utf8mb4', 'utf8mb4_unicode_ci', '')) -_charsets.add(Charset(225, 'utf8mb4', 'utf8mb4_icelandic_ci', '')) -_charsets.add(Charset(226, 'utf8mb4', 'utf8mb4_latvian_ci', '')) -_charsets.add(Charset(227, 'utf8mb4', 'utf8mb4_romanian_ci', '')) -_charsets.add(Charset(228, 'utf8mb4', 'utf8mb4_slovenian_ci', '')) -_charsets.add(Charset(229, 'utf8mb4', 'utf8mb4_polish_ci', '')) -_charsets.add(Charset(230, 'utf8mb4', 'utf8mb4_estonian_ci', '')) -_charsets.add(Charset(231, 'utf8mb4', 'utf8mb4_spanish_ci', '')) -_charsets.add(Charset(232, 'utf8mb4', 'utf8mb4_swedish_ci', '')) -_charsets.add(Charset(233, 'utf8mb4', 'utf8mb4_turkish_ci', '')) -_charsets.add(Charset(234, 'utf8mb4', 'utf8mb4_czech_ci', '')) -_charsets.add(Charset(235, 'utf8mb4', 'utf8mb4_danish_ci', '')) -_charsets.add(Charset(236, 'utf8mb4', 'utf8mb4_lithuanian_ci', '')) -_charsets.add(Charset(237, 'utf8mb4', 'utf8mb4_slovak_ci', '')) -_charsets.add(Charset(238, 'utf8mb4', 'utf8mb4_spanish2_ci', '')) -_charsets.add(Charset(239, 'utf8mb4', 'utf8mb4_roman_ci', '')) -_charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', '')) -_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', '')) -_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', '')) -_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', '')) -_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', '')) -_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', '')) -_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', '')) -_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', '')) -_charsets.add(Charset(248, 'gb18030', 'gb18030_chinese_ci', 'Yes')) -_charsets.add(Charset(249, 'gb18030', 'gb18030_bin', '')) -_charsets.add(Charset(250, 'gb18030', 'gb18030_unicode_520_ci', '')) -_charsets.add(Charset(255, 'utf8mb4', 'utf8mb4_0900_ai_ci', '')) +_charsets.add(Charset(1, "big5", "big5_chinese_ci", "Yes")) +_charsets.add(Charset(2, "latin2", "latin2_czech_cs", "")) +_charsets.add(Charset(3, "dec8", "dec8_swedish_ci", "Yes")) +_charsets.add(Charset(4, "cp850", "cp850_general_ci", "Yes")) +_charsets.add(Charset(5, "latin1", "latin1_german1_ci", "")) +_charsets.add(Charset(6, "hp8", "hp8_english_ci", "Yes")) +_charsets.add(Charset(7, "koi8r", "koi8r_general_ci", "Yes")) +_charsets.add(Charset(8, "latin1", "latin1_swedish_ci", "Yes")) +_charsets.add(Charset(9, "latin2", "latin2_general_ci", "Yes")) +_charsets.add(Charset(10, "swe7", "swe7_swedish_ci", "Yes")) +_charsets.add(Charset(11, "ascii", "ascii_general_ci", "Yes")) +_charsets.add(Charset(12, "ujis", "ujis_japanese_ci", "Yes")) +_charsets.add(Charset(13, "sjis", "sjis_japanese_ci", "Yes")) +_charsets.add(Charset(14, "cp1251", "cp1251_bulgarian_ci", "")) +_charsets.add(Charset(15, "latin1", "latin1_danish_ci", "")) +_charsets.add(Charset(16, "hebrew", "hebrew_general_ci", "Yes")) +_charsets.add(Charset(18, "tis620", "tis620_thai_ci", "Yes")) +_charsets.add(Charset(19, "euckr", "euckr_korean_ci", "Yes")) +_charsets.add(Charset(20, "latin7", "latin7_estonian_cs", "")) +_charsets.add(Charset(21, "latin2", "latin2_hungarian_ci", "")) +_charsets.add(Charset(22, "koi8u", "koi8u_general_ci", "Yes")) +_charsets.add(Charset(23, "cp1251", "cp1251_ukrainian_ci", "")) +_charsets.add(Charset(24, "gb2312", "gb2312_chinese_ci", "Yes")) +_charsets.add(Charset(25, "greek", "greek_general_ci", "Yes")) +_charsets.add(Charset(26, "cp1250", "cp1250_general_ci", "Yes")) +_charsets.add(Charset(27, "latin2", "latin2_croatian_ci", "")) +_charsets.add(Charset(28, "gbk", "gbk_chinese_ci", "Yes")) +_charsets.add(Charset(29, "cp1257", "cp1257_lithuanian_ci", "")) +_charsets.add(Charset(30, "latin5", "latin5_turkish_ci", "Yes")) +_charsets.add(Charset(31, "latin1", "latin1_german2_ci", "")) +_charsets.add(Charset(32, "armscii8", "armscii8_general_ci", "Yes")) +_charsets.add(Charset(33, "utf8", "utf8_general_ci", "Yes")) +_charsets.add(Charset(34, "cp1250", "cp1250_czech_cs", "")) +_charsets.add(Charset(36, "cp866", "cp866_general_ci", "Yes")) +_charsets.add(Charset(37, "keybcs2", "keybcs2_general_ci", "Yes")) +_charsets.add(Charset(38, "macce", "macce_general_ci", "Yes")) +_charsets.add(Charset(39, "macroman", "macroman_general_ci", "Yes")) +_charsets.add(Charset(40, "cp852", "cp852_general_ci", "Yes")) +_charsets.add(Charset(41, "latin7", "latin7_general_ci", "Yes")) +_charsets.add(Charset(42, "latin7", "latin7_general_cs", "")) +_charsets.add(Charset(43, "macce", "macce_bin", "")) +_charsets.add(Charset(44, "cp1250", "cp1250_croatian_ci", "")) +_charsets.add(Charset(45, "utf8mb4", "utf8mb4_general_ci", "Yes")) +_charsets.add(Charset(46, "utf8mb4", "utf8mb4_bin", "")) +_charsets.add(Charset(47, "latin1", "latin1_bin", "")) +_charsets.add(Charset(48, "latin1", "latin1_general_ci", "")) +_charsets.add(Charset(49, "latin1", "latin1_general_cs", "")) +_charsets.add(Charset(50, "cp1251", "cp1251_bin", "")) +_charsets.add(Charset(51, "cp1251", "cp1251_general_ci", "Yes")) +_charsets.add(Charset(52, "cp1251", "cp1251_general_cs", "")) +_charsets.add(Charset(53, "macroman", "macroman_bin", "")) +_charsets.add(Charset(57, "cp1256", "cp1256_general_ci", "Yes")) +_charsets.add(Charset(58, "cp1257", "cp1257_bin", "")) +_charsets.add(Charset(59, "cp1257", "cp1257_general_ci", "Yes")) +_charsets.add(Charset(63, "binary", "binary", "Yes")) +_charsets.add(Charset(64, "armscii8", "armscii8_bin", "")) +_charsets.add(Charset(65, "ascii", "ascii_bin", "")) +_charsets.add(Charset(66, "cp1250", "cp1250_bin", "")) +_charsets.add(Charset(67, "cp1256", "cp1256_bin", "")) +_charsets.add(Charset(68, "cp866", "cp866_bin", "")) +_charsets.add(Charset(69, "dec8", "dec8_bin", "")) +_charsets.add(Charset(70, "greek", "greek_bin", "")) +_charsets.add(Charset(71, "hebrew", "hebrew_bin", "")) +_charsets.add(Charset(72, "hp8", "hp8_bin", "")) +_charsets.add(Charset(73, "keybcs2", "keybcs2_bin", "")) +_charsets.add(Charset(74, "koi8r", "koi8r_bin", "")) +_charsets.add(Charset(75, "koi8u", "koi8u_bin", "")) +_charsets.add(Charset(76, "utf8", "utf8_tolower_ci", "")) +_charsets.add(Charset(77, "latin2", "latin2_bin", "")) +_charsets.add(Charset(78, "latin5", "latin5_bin", "")) +_charsets.add(Charset(79, "latin7", "latin7_bin", "")) +_charsets.add(Charset(80, "cp850", "cp850_bin", "")) +_charsets.add(Charset(81, "cp852", "cp852_bin", "")) +_charsets.add(Charset(82, "swe7", "swe7_bin", "")) +_charsets.add(Charset(83, "utf8", "utf8_bin", "")) +_charsets.add(Charset(84, "big5", "big5_bin", "")) +_charsets.add(Charset(85, "euckr", "euckr_bin", "")) +_charsets.add(Charset(86, "gb2312", "gb2312_bin", "")) +_charsets.add(Charset(87, "gbk", "gbk_bin", "")) +_charsets.add(Charset(88, "sjis", "sjis_bin", "")) +_charsets.add(Charset(89, "tis620", "tis620_bin", "")) +_charsets.add(Charset(91, "ujis", "ujis_bin", "")) +_charsets.add(Charset(92, "geostd8", "geostd8_general_ci", "Yes")) +_charsets.add(Charset(93, "geostd8", "geostd8_bin", "")) +_charsets.add(Charset(94, "latin1", "latin1_spanish_ci", "")) +_charsets.add(Charset(95, "cp932", "cp932_japanese_ci", "Yes")) +_charsets.add(Charset(96, "cp932", "cp932_bin", "")) +_charsets.add(Charset(97, "eucjpms", "eucjpms_japanese_ci", "Yes")) +_charsets.add(Charset(98, "eucjpms", "eucjpms_bin", "")) +_charsets.add(Charset(99, "cp1250", "cp1250_polish_ci", "")) +_charsets.add(Charset(192, "utf8", "utf8_unicode_ci", "")) +_charsets.add(Charset(193, "utf8", "utf8_icelandic_ci", "")) +_charsets.add(Charset(194, "utf8", "utf8_latvian_ci", "")) +_charsets.add(Charset(195, "utf8", "utf8_romanian_ci", "")) +_charsets.add(Charset(196, "utf8", "utf8_slovenian_ci", "")) +_charsets.add(Charset(197, "utf8", "utf8_polish_ci", "")) +_charsets.add(Charset(198, "utf8", "utf8_estonian_ci", "")) +_charsets.add(Charset(199, "utf8", "utf8_spanish_ci", "")) +_charsets.add(Charset(200, "utf8", "utf8_swedish_ci", "")) +_charsets.add(Charset(201, "utf8", "utf8_turkish_ci", "")) +_charsets.add(Charset(202, "utf8", "utf8_czech_ci", "")) +_charsets.add(Charset(203, "utf8", "utf8_danish_ci", "")) +_charsets.add(Charset(204, "utf8", "utf8_lithuanian_ci", "")) +_charsets.add(Charset(205, "utf8", "utf8_slovak_ci", "")) +_charsets.add(Charset(206, "utf8", "utf8_spanish2_ci", "")) +_charsets.add(Charset(207, "utf8", "utf8_roman_ci", "")) +_charsets.add(Charset(208, "utf8", "utf8_persian_ci", "")) +_charsets.add(Charset(209, "utf8", "utf8_esperanto_ci", "")) +_charsets.add(Charset(210, "utf8", "utf8_hungarian_ci", "")) +_charsets.add(Charset(211, "utf8", "utf8_sinhala_ci", "")) +_charsets.add(Charset(212, "utf8", "utf8_german2_ci", "")) +_charsets.add(Charset(213, "utf8", "utf8_croatian_ci", "")) +_charsets.add(Charset(214, "utf8", "utf8_unicode_520_ci", "")) +_charsets.add(Charset(215, "utf8", "utf8_vietnamese_ci", "")) +_charsets.add(Charset(223, "utf8", "utf8_general_mysql500_ci", "")) +_charsets.add(Charset(224, "utf8mb4", "utf8mb4_unicode_ci", "")) +_charsets.add(Charset(225, "utf8mb4", "utf8mb4_icelandic_ci", "")) +_charsets.add(Charset(226, "utf8mb4", "utf8mb4_latvian_ci", "")) +_charsets.add(Charset(227, "utf8mb4", "utf8mb4_romanian_ci", "")) +_charsets.add(Charset(228, "utf8mb4", "utf8mb4_slovenian_ci", "")) +_charsets.add(Charset(229, "utf8mb4", "utf8mb4_polish_ci", "")) +_charsets.add(Charset(230, "utf8mb4", "utf8mb4_estonian_ci", "")) +_charsets.add(Charset(231, "utf8mb4", "utf8mb4_spanish_ci", "")) +_charsets.add(Charset(232, "utf8mb4", "utf8mb4_swedish_ci", "")) +_charsets.add(Charset(233, "utf8mb4", "utf8mb4_turkish_ci", "")) +_charsets.add(Charset(234, "utf8mb4", "utf8mb4_czech_ci", "")) +_charsets.add(Charset(235, "utf8mb4", "utf8mb4_danish_ci", "")) +_charsets.add(Charset(236, "utf8mb4", "utf8mb4_lithuanian_ci", "")) +_charsets.add(Charset(237, "utf8mb4", "utf8mb4_slovak_ci", "")) +_charsets.add(Charset(238, "utf8mb4", "utf8mb4_spanish2_ci", "")) +_charsets.add(Charset(239, "utf8mb4", "utf8mb4_roman_ci", "")) +_charsets.add(Charset(240, "utf8mb4", "utf8mb4_persian_ci", "")) +_charsets.add(Charset(241, "utf8mb4", "utf8mb4_esperanto_ci", "")) +_charsets.add(Charset(242, "utf8mb4", "utf8mb4_hungarian_ci", "")) +_charsets.add(Charset(243, "utf8mb4", "utf8mb4_sinhala_ci", "")) +_charsets.add(Charset(244, "utf8mb4", "utf8mb4_german2_ci", "")) +_charsets.add(Charset(245, "utf8mb4", "utf8mb4_croatian_ci", "")) +_charsets.add(Charset(246, "utf8mb4", "utf8mb4_unicode_520_ci", "")) +_charsets.add(Charset(247, "utf8mb4", "utf8mb4_vietnamese_ci", "")) +_charsets.add(Charset(248, "gb18030", "gb18030_chinese_ci", "Yes")) +_charsets.add(Charset(249, "gb18030", "gb18030_bin", "")) +_charsets.add(Charset(250, "gb18030", "gb18030_unicode_520_ci", "")) +_charsets.add(Charset(255, "utf8mb4", "utf8mb4_0900_ai_ci", "")) charset_by_name = _charsets.by_name charset_by_id = _charsets.by_id diff --git a/pymysql/connections.py b/pymysql/connections.py index 6fd15e13..dc69868b 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -18,14 +18,19 @@ from .cursors import Cursor from .optionfile import Parser from .protocol import ( - dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper, - EOFPacketWrapper, LoadLocalPacketWrapper + dump_packet, + MysqlPacket, + FieldDescriptorPacket, + OKPacketWrapper, + EOFPacketWrapper, + LoadLocalPacketWrapper, ) from .util import byte2int, int2byte from . import err, VERSION_STRING try: import ssl + SSL_ENABLED = True except ImportError: ssl = None @@ -33,6 +38,7 @@ try: import getpass + DEFAULT_USER = getpass.getuser() del getpass except (ImportError, KeyError): @@ -43,8 +49,10 @@ _py_version = sys.version_info[:2] + def _fast_surrogateescape(s): - return s.decode('ascii', 'surrogateescape') + return s.decode("ascii", "surrogateescape") + def _makefile(sock, mode): return sock.makefile(mode) @@ -63,29 +71,34 @@ def _makefile(sock, mode): } -DEFAULT_CHARSET = 'utf8mb4' +DEFAULT_CHARSET = "utf8mb4" -MAX_PACKET_LEN = 2**24-1 +MAX_PACKET_LEN = 2 ** 24 - 1 def pack_int24(n): - return struct.pack(' 2: use_unicode = True @@ -184,7 +224,9 @@ def __init__(self, host=None, user=None, password="", password = passwd if compress or named_pipe: - raise NotImplementedError("compress and named_pipe arguments are not supported") + raise NotImplementedError( + "compress and named_pipe arguments are not supported" + ) self._local_infile = bool(local_infile) if self._local_infile: @@ -233,12 +275,14 @@ def _config(key, arg): ssl = { "ca": ssl_ca, "check_hostname": bool(ssl_verify_identity), - "verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False, + "verify_mode": ssl_verify_cert + if ssl_verify_cert is not None + else False, } if ssl_cert is not None: ssl["cert"] = ssl_cert if ssl_key is not None: - ssl["key" ] = ssl_key + ssl["key"] = ssl_key if ssl: if not SSL_ENABLED: raise NotImplementedError("ssl module not found") @@ -253,7 +297,7 @@ def _config(key, arg): self.user = user or DEFAULT_USER self.password = password or b"" if isinstance(self.password, str): - self.password = self.password.encode('latin1') + self.password = self.password.encode("latin1") self.db = database self.unix_socket = unix_socket self.bind_address = bind_address @@ -307,9 +351,9 @@ def _config(key, arg): self.server_public_key = server_public_key self._connect_attrs = { - '_client_name': 'pymysql', - '_pid': str(os.getpid()), - '_client_version': VERSION_STRING, + "_client_name": "pymysql", + "_pid": str(os.getpid()), + "_client_version": VERSION_STRING, } if program_name: @@ -319,23 +363,23 @@ def _config(key, arg): self._sock = None else: self.connect() - + def __enter__(self): return self - + def __exit__(self, *exc_info): del exc_info self.close() - + def _create_ssl_ctx(self, sslp): if isinstance(sslp, ssl.SSLContext): return sslp - ca = sslp.get('ca') - capath = sslp.get('capath') + ca = sslp.get("ca") + capath = sslp.get("capath") hasnoca = ca is None and capath is None ctx = ssl.create_default_context(cafile=ca, capath=capath) - ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True) - verify_mode_value = sslp.get('verify_mode') + ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) + verify_mode_value = sslp.get("verify_mode") if verify_mode_value is None: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED elif isinstance(verify_mode_value, bool): @@ -351,10 +395,10 @@ def _create_ssl_ctx(self, sslp): ctx.verify_mode = ssl.CERT_REQUIRED else: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED - if 'cert' in sslp: - ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key')) - if 'cipher' in sslp: - ctx.set_ciphers(sslp['cipher']) + if "cert" in sslp: + ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) + if "cipher" in sslp: + ctx.set_ciphers(sslp["cipher"]) ctx.options |= ssl.OP_NO_SSLv2 ctx.options |= ssl.OP_NO_SSLv3 return ctx @@ -373,7 +417,7 @@ def close(self): self._closed = True if self._sock is None: return - send_data = struct.pack('= 5: + if int(self.server_version.split(".", 1)[0]) >= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: @@ -800,28 +851,30 @@ def _request_authentication(self): if isinstance(self.user, str): self.user = self.user.encode(self.encoding) - data_init = struct.pack('=5.0) - data += authresp + b'\0' + data += authresp + b"\0" if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: if isinstance(self.db, str): self.db = self.db.encode(self.encoding) - data += self.db + b'\0' + data += self.db + b"\0" if self.server_capabilities & CLIENT.PLUGIN_AUTH: - data += (plugin_name or b'') + b'\0' + data += (plugin_name or b"") + b"\0" if self.server_capabilities & CLIENT.CONNECT_ATTRS: - connect_attrs = b'' + connect_attrs = b"" for k, v in self._connect_attrs.items(): - k = k.encode('utf-8') - connect_attrs += struct.pack('B', len(k)) + k - v = v.encode('utf-8') - connect_attrs += struct.pack('B', len(v)) + v - data += struct.pack('B', len(connect_attrs)) + connect_attrs + k = k.encode("utf-8") + connect_attrs += struct.pack("B", len(k)) + k + v = v.encode("utf-8") + connect_attrs += struct.pack("B", len(v)) + v + data += struct.pack("B", len(connect_attrs)) + connect_attrs self.write_packet(data) auth_packet = self._read_packet() @@ -868,15 +921,19 @@ def _request_authentication(self): # if authentication method isn't accepted the first byte # will have the octet 254 if auth_packet.is_auth_switch_request(): - if DEBUG: print("received auth switch") + if DEBUG: + print("received auth switch") # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - auth_packet.read_uint8() # 0xfe packet identifier + auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() - if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None: + if ( + self.server_capabilities & CLIENT.PLUGIN_AUTH + and plugin_name is not None + ): auth_packet = self._process_auth(plugin_name, auth_packet) else: # send legacy handshake - data = _auth.scramble_old_password(self.password, self.salt) + b'\0' + data = _auth.scramble_old_password(self.password, self.salt) + b"\0" self.write_packet(data) auth_packet = self._read_packet() elif auth_packet.is_extra_auth_data(): @@ -888,9 +945,12 @@ def _request_authentication(self): elif self._auth_plugin_name == "sha256_password": auth_packet = _auth.sha256_password_auth(self, auth_packet) else: - raise err.OperationalError("Received extra packet for auth method %r", self._auth_plugin_name) + raise err.OperationalError( + "Received extra packet for auth method %r", self._auth_plugin_name + ) - if DEBUG: print("Succeed to auth") + if DEBUG: + print("Succeed to auth") def _process_auth(self, plugin_name, auth_packet): handler = self._get_auth_plugin_handler(plugin_name) @@ -898,22 +958,29 @@ def _process_auth(self, plugin_name, auth_packet): try: return handler.authenticate(auth_packet) except AttributeError: - if plugin_name != b'dialog': - raise err.OperationalError(2059, "Authentication plugin '%s'" - " not loaded: - %r missing authenticate method" % (plugin_name, type(handler))) + if plugin_name != b"dialog": + raise err.OperationalError( + 2059, + "Authentication plugin '%s'" + " not loaded: - %r missing authenticate method" + % (plugin_name, type(handler)), + ) if plugin_name == b"caching_sha2_password": return _auth.caching_sha2_password_auth(self, auth_packet) elif plugin_name == b"sha256_password": return _auth.sha256_password_auth(self, auth_packet) elif plugin_name == b"mysql_native_password": data = _auth.scramble_native_password(self.password, auth_packet.read_all()) - elif plugin_name == b'client_ed25519': + elif plugin_name == b"client_ed25519": data = _auth.ed25519_password(self.password, auth_packet.read_all()) elif plugin_name == b"mysql_old_password": - data = _auth.scramble_old_password(self.password, auth_packet.read_all()) + b'\0' + data = ( + _auth.scramble_old_password(self.password, auth_packet.read_all()) + + b"\0" + ) elif plugin_name == b"mysql_clear_password": # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html - data = self.password + b'\0' + data = self.password + b"\0" elif plugin_name == b"dialog": pkt = auth_packet while True: @@ -923,27 +990,41 @@ def _process_auth(self, plugin_name, auth_packet): prompt = pkt.read_all() if prompt == b"Password: ": - self.write_packet(self.password + b'\0') + self.write_packet(self.password + b"\0") elif handler: - resp = 'no response - TypeError within plugin.prompt method' + resp = "no response - TypeError within plugin.prompt method" try: resp = handler.prompt(echo, prompt) - self.write_packet(resp + b'\0') + self.write_packet(resp + b"\0") except AttributeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r missing prompt method" % (plugin_name, handler)) + raise err.OperationalError( + 2059, + "Authentication plugin '%s'" + " not loaded: - %r missing prompt method" + % (plugin_name, handler), + ) except TypeError: - raise err.OperationalError(2061, "Authentication plugin '%s'" \ - " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt)) + raise err.OperationalError( + 2061, + "Authentication plugin '%s'" + " %r didn't respond with string. Returned '%r' to prompt %r" + % (plugin_name, handler, resp, prompt), + ) else: - raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler)) + raise err.OperationalError( + 2059, + "Authentication plugin '%s' (%r) not configured" + % (plugin_name, handler), + ) pkt = self._read_packet() pkt.check_error() if pkt.is_ok_packet() or last: break return pkt else: - raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name) + raise err.OperationalError( + 2059, "Authentication plugin '%s' not configured" % plugin_name + ) self.write_packet(data) pkt = self._read_packet() @@ -953,13 +1034,17 @@ def _process_auth(self, plugin_name, auth_packet): def _get_auth_plugin_handler(self, plugin_name): plugin_class = self._auth_plugin_map.get(plugin_name) if not plugin_class and isinstance(plugin_name, bytes): - plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) + plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii")) if plugin_class: try: handler = plugin_class(self) except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" - " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) + raise err.OperationalError( + 2059, + "Authentication plugin '%s'" + " not loaded: - %r cannot be constructed with connection object" + % (plugin_name, plugin_class), + ) else: handler = None return handler @@ -982,24 +1067,24 @@ def _get_server_information(self): packet = self._read_packet() data = packet.get_all_data() - self.protocol_version = byte2int(data[i:i+1]) + self.protocol_version = byte2int(data[i : i + 1]) i += 1 - server_end = data.find(b'\0', i) - self.server_version = data[i:server_end].decode('latin1') + server_end = data.find(b"\0", i) + self.server_version = data[i:server_end].decode("latin1") i = server_end + 1 - self.server_thread_id = struct.unpack('= i + 6: - lang, stat, cap_h, salt_len = struct.unpack('= i + salt_len: # salt_len includes auth_plugin_data_part_1 and filler - self.salt += data[i:i+salt_len] + self.salt += data[i : i + salt_len] i += salt_len - i+=1 + i += 1 # AUTH PLUGIN NAME may appear here. if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: # Due to Bug#59453 the auth-plugin-name is missing the terminating @@ -1033,12 +1120,12 @@ def _get_server_information(self): # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake # didn't use version checks as mariadb is corrected and reports # earlier than those two. - server_end = data.find(b'\0', i) - if server_end < 0: # pragma: no cover - very specific upstream bug + server_end = data.find(b"\0", i) + if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all - self._auth_plugin_name = data[i:].decode('utf-8') + self._auth_plugin_name = data[i:].decode("utf-8") else: - self._auth_plugin_name = data[i:server_end].decode('utf-8') + self._auth_plugin_name = data[i:server_end].decode("utf-8") def get_server_info(self): return self.server_version @@ -1056,7 +1143,6 @@ def get_server_info(self): class MySQLResult: - def __init__(self, connection): """ :type connection: Connection @@ -1127,7 +1213,8 @@ def _read_ok_packet(self, first_packet): def _read_load_local_packet(self, first_packet): if not self.connection._local_infile: raise RuntimeError( - "**WARN**: Received LOAD_LOCAL packet but local_infile option is false.") + "**WARN**: Received LOAD_LOCAL packet but local_infile option is false." + ) load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) try: @@ -1137,14 +1224,16 @@ def _read_load_local_packet(self, first_packet): raise ok_packet = self.connection._read_packet() - if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error + if ( + not ok_packet.is_ok_packet() + ): # pragma: no cover - upstream induced protocol error raise err.OperationalError(2014, "Commands Out of Sync") self._read_ok_packet(ok_packet) def _check_packet_is_eof(self, packet): if not packet.is_eof_packet(): return False - #TODO: Support CLIENT.DEPRECATE_EOF + # TODO: Support CLIENT.DEPRECATE_EOF # 1) Add DEPRECATE_EOF to CAPABILITIES # 2) Mask CAPABILITIES with server_capabilities # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper @@ -1211,7 +1300,8 @@ def _read_row_from_packet(self, packet): if data is not None: if encoding is not None: data = data.decode(encoding) - if DEBUG: print("DEBUG: DATA = ", data) + if DEBUG: + print("DEBUG: DATA = ", data) if converter is not None: data = converter(data) row.append(data) @@ -1246,17 +1336,18 @@ def _get_descriptions(self): encoding = conn_encoding else: # Integers, Dates and Times, and other basic data is encoded in ascii - encoding = 'ascii' + encoding = "ascii" else: encoding = None converter = self.connection.decoders.get(field_type) if converter is converters.through: converter = None - if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter)) + if DEBUG: + print("DEBUG: field={}, converter={}".format(field, converter)) self.converters.append((encoding, converter)) eof_packet = self.connection._read_packet() - assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF' + assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF" self.description = tuple(description) @@ -1268,19 +1359,23 @@ def __init__(self, filename, connection): def send_data(self): """Send data packets from the local file to the server""" if not self.connection._sock: - raise err.InterfaceError(0, '') + raise err.InterfaceError(0, "") conn = self.connection try: - with open(self.filename, 'rb') as open_file: - packet_size = min(conn.max_allowed_packet, 16*1024) # 16KB is efficient enough + with open(self.filename, "rb") as open_file: + packet_size = min( + conn.max_allowed_packet, 16 * 1024 + ) # 16KB is efficient enough while True: chunk = open_file.read(packet_size) if not chunk: break conn.write_packet(chunk) except IOError: - raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename)) + raise err.OperationalError( + 1017, "Can't find file '{0}'".format(self.filename) + ) finally: # send the empty packet to signify we are done sending data - conn.write_packet(b'') + conn.write_packet(b"") diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py index b42f1523..34fe57a5 100644 --- a/pymysql/constants/CLIENT.py +++ b/pymysql/constants/CLIENT.py @@ -21,9 +21,16 @@ CONNECT_ATTRS = 1 << 20 PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21 CAPABILITIES = ( - LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS - | SECURE_CONNECTION | MULTI_RESULTS - | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) + LONG_PASSWORD + | LONG_FLAG + | PROTOCOL_41 + | TRANSACTIONS + | SECURE_CONNECTION + | MULTI_RESULTS + | PLUGIN_AUTH + | PLUGIN_AUTH_LENENC_CLIENT_DATA + | CONNECT_ATTRS +) # Not done yet HANDLE_EXPIRED_PASSWORDS = 1 << 22 diff --git a/pymysql/constants/COMMAND.py b/pymysql/constants/COMMAND.py index 1da27553..2d98850b 100644 --- a/pymysql/constants/COMMAND.py +++ b/pymysql/constants/COMMAND.py @@ -1,4 +1,3 @@ - COM_SLEEP = 0x00 COM_QUIT = 0x01 COM_INIT_DB = 0x02 @@ -9,12 +8,12 @@ COM_REFRESH = 0x07 COM_SHUTDOWN = 0x08 COM_STATISTICS = 0x09 -COM_PROCESS_INFO = 0x0a -COM_CONNECT = 0x0b -COM_PROCESS_KILL = 0x0c -COM_DEBUG = 0x0d -COM_PING = 0x0e -COM_TIME = 0x0f +COM_PROCESS_INFO = 0x0A +COM_CONNECT = 0x0B +COM_PROCESS_KILL = 0x0C +COM_DEBUG = 0x0D +COM_PING = 0x0E +COM_TIME = 0x0F COM_DELAYED_INSERT = 0x10 COM_CHANGE_USER = 0x11 COM_BINLOG_DUMP = 0x12 @@ -25,9 +24,9 @@ COM_STMT_EXECUTE = 0x17 COM_STMT_SEND_LONG_DATA = 0x18 COM_STMT_CLOSE = 0x19 -COM_STMT_RESET = 0x1a -COM_SET_OPTION = 0x1b -COM_STMT_FETCH = 0x1c -COM_DAEMON = 0x1d -COM_BINLOG_DUMP_GTID = 0x1e -COM_END = 0x1f +COM_STMT_RESET = 0x1A +COM_SET_OPTION = 0x1B +COM_STMT_FETCH = 0x1C +COM_DAEMON = 0x1D +COM_BINLOG_DUMP_GTID = 0x1E +COM_END = 0x1F diff --git a/pymysql/constants/CR.py b/pymysql/constants/CR.py index 48ca956e..25579a7c 100644 --- a/pymysql/constants/CR.py +++ b/pymysql/constants/CR.py @@ -1,68 +1,68 @@ # flake8: noqa # errmsg.h -CR_ERROR_FIRST = 2000 -CR_UNKNOWN_ERROR = 2000 -CR_SOCKET_CREATE_ERROR = 2001 -CR_CONNECTION_ERROR = 2002 -CR_CONN_HOST_ERROR = 2003 -CR_IPSOCK_ERROR = 2004 -CR_UNKNOWN_HOST = 2005 -CR_SERVER_GONE_ERROR = 2006 -CR_VERSION_ERROR = 2007 -CR_OUT_OF_MEMORY = 2008 -CR_WRONG_HOST_INFO = 2009 +CR_ERROR_FIRST = 2000 +CR_UNKNOWN_ERROR = 2000 +CR_SOCKET_CREATE_ERROR = 2001 +CR_CONNECTION_ERROR = 2002 +CR_CONN_HOST_ERROR = 2003 +CR_IPSOCK_ERROR = 2004 +CR_UNKNOWN_HOST = 2005 +CR_SERVER_GONE_ERROR = 2006 +CR_VERSION_ERROR = 2007 +CR_OUT_OF_MEMORY = 2008 +CR_WRONG_HOST_INFO = 2009 CR_LOCALHOST_CONNECTION = 2010 -CR_TCP_CONNECTION = 2011 +CR_TCP_CONNECTION = 2011 CR_SERVER_HANDSHAKE_ERR = 2012 -CR_SERVER_LOST = 2013 +CR_SERVER_LOST = 2013 CR_COMMANDS_OUT_OF_SYNC = 2014 CR_NAMEDPIPE_CONNECTION = 2015 -CR_NAMEDPIPEWAIT_ERROR = 2016 -CR_NAMEDPIPEOPEN_ERROR = 2017 +CR_NAMEDPIPEWAIT_ERROR = 2016 +CR_NAMEDPIPEOPEN_ERROR = 2017 CR_NAMEDPIPESETSTATE_ERROR = 2018 -CR_CANT_READ_CHARSET = 2019 +CR_CANT_READ_CHARSET = 2019 CR_NET_PACKET_TOO_LARGE = 2020 -CR_EMBEDDED_CONNECTION = 2021 -CR_PROBE_SLAVE_STATUS = 2022 -CR_PROBE_SLAVE_HOSTS = 2023 -CR_PROBE_SLAVE_CONNECT = 2024 +CR_EMBEDDED_CONNECTION = 2021 +CR_PROBE_SLAVE_STATUS = 2022 +CR_PROBE_SLAVE_HOSTS = 2023 +CR_PROBE_SLAVE_CONNECT = 2024 CR_PROBE_MASTER_CONNECT = 2025 CR_SSL_CONNECTION_ERROR = 2026 -CR_MALFORMED_PACKET = 2027 -CR_WRONG_LICENSE = 2028 +CR_MALFORMED_PACKET = 2027 +CR_WRONG_LICENSE = 2028 -CR_NULL_POINTER = 2029 -CR_NO_PREPARE_STMT = 2030 -CR_PARAMS_NOT_BOUND = 2031 -CR_DATA_TRUNCATED = 2032 +CR_NULL_POINTER = 2029 +CR_NO_PREPARE_STMT = 2030 +CR_PARAMS_NOT_BOUND = 2031 +CR_DATA_TRUNCATED = 2032 CR_NO_PARAMETERS_EXISTS = 2033 CR_INVALID_PARAMETER_NO = 2034 -CR_INVALID_BUFFER_USE = 2035 +CR_INVALID_BUFFER_USE = 2035 CR_UNSUPPORTED_PARAM_TYPE = 2036 -CR_SHARED_MEMORY_CONNECTION = 2037 -CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038 -CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039 +CR_SHARED_MEMORY_CONNECTION = 2037 +CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038 +CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039 CR_SHARED_MEMORY_CONNECT_FILE_MAP_ERROR = 2040 -CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041 -CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042 -CR_SHARED_MEMORY_MAP_ERROR = 2043 -CR_SHARED_MEMORY_EVENT_ERROR = 2044 +CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041 +CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042 +CR_SHARED_MEMORY_MAP_ERROR = 2043 +CR_SHARED_MEMORY_EVENT_ERROR = 2044 CR_SHARED_MEMORY_CONNECT_ABANDONED_ERROR = 2045 -CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046 -CR_CONN_UNKNOW_PROTOCOL = 2047 -CR_INVALID_CONN_HANDLE = 2048 -CR_SECURE_AUTH = 2049 -CR_FETCH_CANCELED = 2050 -CR_NO_DATA = 2051 -CR_NO_STMT_METADATA = 2052 -CR_NO_RESULT_SET = 2053 -CR_NOT_IMPLEMENTED = 2054 -CR_SERVER_LOST_EXTENDED = 2055 -CR_STMT_CLOSED = 2056 -CR_NEW_STMT_METADATA = 2057 -CR_ALREADY_CONNECTED = 2058 -CR_AUTH_PLUGIN_CANNOT_LOAD = 2059 -CR_DUPLICATE_CONNECTION_ATTR = 2060 -CR_AUTH_PLUGIN_ERR = 2061 +CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046 +CR_CONN_UNKNOW_PROTOCOL = 2047 +CR_INVALID_CONN_HANDLE = 2048 +CR_SECURE_AUTH = 2049 +CR_FETCH_CANCELED = 2050 +CR_NO_DATA = 2051 +CR_NO_STMT_METADATA = 2052 +CR_NO_RESULT_SET = 2053 +CR_NOT_IMPLEMENTED = 2054 +CR_SERVER_LOST_EXTENDED = 2055 +CR_STMT_CLOSED = 2056 +CR_NEW_STMT_METADATA = 2057 +CR_ALREADY_CONNECTED = 2058 +CR_AUTH_PLUGIN_CANNOT_LOAD = 2059 +CR_DUPLICATE_CONNECTION_ATTR = 2060 +CR_AUTH_PLUGIN_ERR = 2061 CR_ERROR_LAST = 2061 diff --git a/pymysql/constants/FIELD_TYPE.py b/pymysql/constants/FIELD_TYPE.py index 51bd5143..b8b44866 100644 --- a/pymysql/constants/FIELD_TYPE.py +++ b/pymysql/constants/FIELD_TYPE.py @@ -1,5 +1,3 @@ - - DECIMAL = 0 TINY = 1 SHORT = 2 diff --git a/pymysql/constants/SERVER_STATUS.py b/pymysql/constants/SERVER_STATUS.py index 6f5d5663..8f8d7768 100644 --- a/pymysql/constants/SERVER_STATUS.py +++ b/pymysql/constants/SERVER_STATUS.py @@ -1,4 +1,3 @@ - SERVER_STATUS_IN_TRANS = 1 SERVER_STATUS_AUTOCOMMIT = 2 SERVER_MORE_RESULTS_EXISTS = 8 diff --git a/pymysql/converters.py b/pymysql/converters.py index 6d1fc9ee..113dd298 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -25,6 +25,7 @@ def escape_item(val, charset, mapping=None): val = encoder(val, mapping) return val + def escape_dict(val, charset, mapping=None): n = {} for k, v in val.items(): @@ -32,6 +33,7 @@ def escape_dict(val, charset, mapping=None): n[k] = quoted return n + def escape_sequence(val, charset, mapping=None): n = [] for item in val: @@ -39,32 +41,38 @@ def escape_sequence(val, charset, mapping=None): n.append(quoted) return "(" + ",".join(n) + ")" + def escape_set(val, charset, mapping=None): - return ','.join([escape_item(x, charset, mapping) for x in val]) + return ",".join([escape_item(x, charset, mapping) for x in val]) + def escape_bool(value, mapping=None): return str(int(value)) + def escape_int(value, mapping=None): return str(value) + def escape_float(value, mapping=None): s = repr(value) - if s in ('inf', 'nan'): + if s in ("inf", "nan"): raise ProgrammingError("%s can not be used with MySQL" % s) - if 'e' not in s: - s += 'e0' + if "e" not in s: + s += "e0" return s + _escape_table = [chr(x) for x in range(128)] -_escape_table[0] = u'\\0' -_escape_table[ord('\\')] = u'\\\\' -_escape_table[ord('\n')] = u'\\n' -_escape_table[ord('\r')] = u'\\r' -_escape_table[ord('\032')] = u'\\Z' +_escape_table[0] = u"\\0" +_escape_table[ord("\\")] = u"\\\\" +_escape_table[ord("\n")] = u"\\n" +_escape_table[ord("\r")] = u"\\r" +_escape_table[ord("\032")] = u"\\Z" _escape_table[ord('"')] = u'\\"' _escape_table[ord("'")] = u"\\'" + def escape_string(value, mapping=None): """escapes *value* without adding quote. @@ -74,18 +82,22 @@ def escape_string(value, mapping=None): def escape_bytes_prefixed(value, mapping=None): - return "_binary'%s'" % value.decode('ascii', 'surrogateescape').translate(_escape_table) + return "_binary'%s'" % value.decode("ascii", "surrogateescape").translate( + _escape_table + ) def escape_bytes(value, mapping=None): - return "'%s'" % value.decode('ascii', 'surrogateescape').translate(_escape_table) + return "'%s'" % value.decode("ascii", "surrogateescape").translate(_escape_table) def escape_str(value, mapping=None): return "'%s'" % escape_string(str(value), mapping) + def escape_None(value, mapping=None): - return 'NULL' + return "NULL" + def escape_timedelta(obj, mapping=None): seconds = int(obj.seconds) % 60 @@ -97,6 +109,7 @@ def escape_timedelta(obj, mapping=None): fmt = "'{0:02d}:{1:02d}:{2:02d}'" return fmt.format(hours, minutes, seconds, obj.microseconds) + def escape_time(obj, mapping=None): if obj.microsecond: fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" @@ -104,6 +117,7 @@ def escape_time(obj, mapping=None): fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'" return fmt.format(obj) + def escape_datetime(obj, mapping=None): if obj.microsecond: fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" @@ -111,10 +125,12 @@ def escape_datetime(obj, mapping=None): fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'" return fmt.format(obj) + def escape_date(obj, mapping=None): fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'" return fmt.format(obj) + def escape_struct_time(obj, mapping=None): return escape_datetime(datetime.datetime(*obj[:6])) @@ -127,10 +143,13 @@ def _convert_second_fraction(s): if not s: return 0 # Pad zeros to ensure the fraction length in microseconds - s = s.ljust(6, '0') + s = s.ljust(6, "0") return int(s[:6]) -DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") + +DATETIME_RE = re.compile( + r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?" +) def convert_datetime(obj): @@ -150,7 +169,7 @@ def convert_datetime(obj): """ if isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + obj = obj.decode("ascii") m = DATETIME_RE.match(obj) if not m: @@ -159,10 +178,11 @@ def convert_datetime(obj): try: groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) - return datetime.datetime(*[ int(x) for x in groups ]) + return datetime.datetime(*[int(x) for x in groups]) except ValueError: return convert_date(obj) + TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") @@ -184,7 +204,7 @@ def convert_timedelta(obj): be parsed correctly by this function. """ if isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + obj = obj.decode("ascii") m = TIMEDELTA_RE.match(obj) if not m: @@ -196,16 +216,20 @@ def convert_timedelta(obj): negate = -1 if groups[0] else 1 hours, minutes, seconds, microseconds = groups[1:] - tdelta = datetime.timedelta( - hours = int(hours), - minutes = int(minutes), - seconds = int(seconds), - microseconds = int(microseconds) - ) * negate + tdelta = ( + datetime.timedelta( + hours=int(hours), + minutes=int(minutes), + seconds=int(seconds), + microseconds=int(microseconds), + ) + * negate + ) return tdelta except ValueError: return obj + TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") @@ -232,7 +256,7 @@ def convert_time(obj): use set this function as the converter for FIELD_TYPE.TIME. """ if isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + obj = obj.decode("ascii") m = TIME_RE.match(obj) if not m: @@ -242,8 +266,12 @@ def convert_time(obj): groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) hours, minutes, seconds, microseconds = groups - return datetime.time(hour=int(hours), minute=int(minutes), - second=int(seconds), microsecond=int(microseconds)) + return datetime.time( + hour=int(hours), + minute=int(minutes), + second=int(seconds), + microsecond=int(microseconds), + ) except ValueError: return obj @@ -263,9 +291,9 @@ def convert_date(obj): """ if isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + obj = obj.decode("ascii") try: - return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) + return datetime.date(*[int(x) for x in obj.split("-", 2)]) except ValueError: return obj @@ -274,7 +302,7 @@ def through(x): return x -#def convert_bit(b): +# def convert_bit(b): # b = "\x00" * (8 - len(b)) + b # pad w/ zeroes # return struct.unpack(">Q", b)[0] # diff --git a/pymysql/cursors.py b/pymysql/cursors.py index a8c52836..68ac78e7 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -6,10 +6,11 @@ #: executemany only supports simple bulk insert. #: You can use it to load large dataset. RE_INSERT_VALUES = re.compile( - r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + - r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + - r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", - re.IGNORECASE | re.DOTALL) + r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL, +) class Cursor: @@ -167,16 +168,23 @@ def executemany(self, query, args): if m: q_prefix = m.group(1) % () q_values = m.group(2).rstrip() - q_postfix = m.group(3) or '' - assert q_values[0] == '(' and q_values[-1] == ')' - return self._do_execute_many(q_prefix, q_values, q_postfix, args, - self.max_stmt_length, - self._get_db().encoding) + q_postfix = m.group(3) or "" + assert q_values[0] == "(" and q_values[-1] == ")" + return self._do_execute_many( + q_prefix, + q_values, + q_postfix, + args, + self.max_stmt_length, + self._get_db().encoding, + ) self.rowcount = sum(self.execute(query, arg) for arg in args) return self.rowcount - def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + def _do_execute_many( + self, prefix, values, postfix, args, max_stmt_length, encoding + ): conn = self._get_db() escape = self._escape_args if isinstance(prefix, str): @@ -187,18 +195,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encod args = iter(args) v = values % escape(next(args), conn) if isinstance(v, str): - v = v.encode(encoding, 'surrogateescape') + v = v.encode(encoding, "surrogateescape") sql += v rows = 0 for arg in args: v = values % escape(arg, conn) if isinstance(v, str): - v = v.encode(encoding, 'surrogateescape') + v = v.encode(encoding, "surrogateescape") if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: rows += self.execute(sql + postfix) sql = bytearray(prefix) else: - sql += b',' + sql += b"," sql += v rows += self.execute(sql + postfix) self.rowcount = rows @@ -234,14 +242,19 @@ def callproc(self, procname, args=()): """ conn = self._get_db() if args: - fmt = '@_{0}_%d=%s'.format(procname) - self._query('SET %s' % ','.join(fmt % (index, conn.escape(arg)) - for index, arg in enumerate(args))) + fmt = "@_{0}_%d=%s".format(procname) + self._query( + "SET %s" + % ",".join( + fmt % (index, conn.escape(arg)) for index, arg in enumerate(args) + ) + ) self.nextset() - q = "CALL %s(%s)" % (procname, - ','.join(['@_%s_%d' % (procname, i) - for i in range(len(args))])) + q = "CALL %s(%s)" % ( + procname, + ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]), + ) self._query(q) self._executed = q return args @@ -261,7 +274,7 @@ def fetchmany(self, size=None): if self._rows is None: return () end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] + result = self._rows[self.rownumber : end] self.rownumber = min(end, len(self._rows)) return result @@ -271,17 +284,17 @@ def fetchall(self): if self._rows is None: return () if self.rownumber: - result = self._rows[self.rownumber:] + result = self._rows[self.rownumber :] else: result = self._rows self.rownumber = len(self._rows) return result - def scroll(self, value, mode='relative'): + def scroll(self, value, mode="relative"): self._check_executed() - if mode == 'relative': + if mode == "relative": r = self.rownumber + value - elif mode == 'absolute': + elif mode == "absolute": r = value else: raise err.ProgrammingError("unknown scroll mode %s" % mode) @@ -343,7 +356,7 @@ def _do_get_result(self): for f in self._result.fields: name = f.name if name in fields: - name = f.table_name + '.' + name + name = f.table_name + "." + name fields.append(name) self._fields = fields @@ -453,21 +466,23 @@ def fetchmany(self, size=None): self.rownumber += 1 return rows - def scroll(self, value, mode='relative'): + def scroll(self, value, mode="relative"): self._check_executed() - if mode == 'relative': + if mode == "relative": if value < 0: raise err.NotSupportedError( - "Backwards scrolling not supported by this cursor") + "Backwards scrolling not supported by this cursor" + ) for _ in range(value): self.read_next() self.rownumber += value - elif mode == 'absolute': + elif mode == "absolute": if value < self.rownumber: raise err.NotSupportedError( - "Backwards scrolling not supported by this cursor") + "Backwards scrolling not supported by this cursor" + ) end = value - self.rownumber for _ in range(end): diff --git a/pymysql/err.py b/pymysql/err.py index 94100cfe..3da5b166 100644 --- a/pymysql/err.py +++ b/pymysql/err.py @@ -74,33 +74,69 @@ def _map_error(exc, *errors): error_map[error] = exc -_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR, - ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME, - ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE, - ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION, - ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION, - ER.WRONG_DB_NAME, ER.WRONG_COLUMN_NAME, - ) -_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL, - ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL, - ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW, ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, - ER.ILLEGAL_VALUE_FOR_TYPE) -_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW, - ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2, - ER.CANNOT_ADD_FOREIGN, ER.BAD_NULL_ERROR) -_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK, - ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE) -_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR, - ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR, - ER.COLUMNACCESS_DENIED_ERROR, ER.CONSTRAINT_FAILED, ER.LOCK_DEADLOCK) +_map_error( + ProgrammingError, + ER.DB_CREATE_EXISTS, + ER.SYNTAX_ERROR, + ER.PARSE_ERROR, + ER.NO_SUCH_TABLE, + ER.WRONG_DB_NAME, + ER.WRONG_TABLE_NAME, + ER.FIELD_SPECIFIED_TWICE, + ER.INVALID_GROUP_FUNC_USE, + ER.UNSUPPORTED_EXTENSION, + ER.TABLE_MUST_HAVE_COLUMNS, + ER.CANT_DO_THIS_DURING_AN_TRANSACTION, + ER.WRONG_DB_NAME, + ER.WRONG_COLUMN_NAME, +) +_map_error( + DataError, + ER.WARN_DATA_TRUNCATED, + ER.WARN_NULL_TO_NOTNULL, + ER.WARN_DATA_OUT_OF_RANGE, + ER.NO_DEFAULT, + ER.PRIMARY_CANT_HAVE_NULL, + ER.DATA_TOO_LONG, + ER.DATETIME_FUNCTION_OVERFLOW, + ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, + ER.ILLEGAL_VALUE_FOR_TYPE, +) +_map_error( + IntegrityError, + ER.DUP_ENTRY, + ER.NO_REFERENCED_ROW, + ER.NO_REFERENCED_ROW_2, + ER.ROW_IS_REFERENCED, + ER.ROW_IS_REFERENCED_2, + ER.CANNOT_ADD_FOREIGN, + ER.BAD_NULL_ERROR, +) +_map_error( + NotSupportedError, + ER.WARNING_NOT_COMPLETE_ROLLBACK, + ER.NOT_SUPPORTED_YET, + ER.FEATURE_DISABLED, + ER.UNKNOWN_STORAGE_ENGINE, +) +_map_error( + OperationalError, + ER.DBACCESS_DENIED_ERROR, + ER.ACCESS_DENIED_ERROR, + ER.CON_COUNT_ERROR, + ER.TABLEACCESS_DENIED_ERROR, + ER.COLUMNACCESS_DENIED_ERROR, + ER.CONSTRAINT_FAILED, + ER.LOCK_DEADLOCK, +) del _map_error, ER def raise_mysql_exception(data): - errno = struct.unpack('= 2 and value[0] == value[-1] == quote: return value[1:-1] diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 541475ad..24b3f23e 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -25,7 +25,7 @@ def printable(data): if isinstance(data, int): return chr(data) return data - return '.' + return "." try: print("packet length:", len(data)) @@ -35,11 +35,14 @@ def printable(data): print("-" * 66) except ValueError: pass - dump_data = [data[i:i+16] for i in range(0, min(len(data), 256), 16)] + dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)] for d in dump_data: - print(' '.join("{:02X}".format(byte2int(x)) for x in d) + - ' ' * (16 - len(d)) + ' ' * 2 + - ''.join(printable(x) for x in d)) + print( + " ".join("{:02X}".format(byte2int(x)) for x in d) + + " " * (16 - len(d)) + + " " * 2 + + "".join(printable(x) for x in d) + ) print("-" * 66) print() @@ -49,7 +52,8 @@ class MysqlPacket: Provides an interface for reading/parsing the packet results. """ - __slots__ = ('_position', '_data') + + __slots__ = ("_position", "_data") def __init__(self, data, encoding): self._position = 0 @@ -60,11 +64,13 @@ def get_all_data(self): def read(self, size): """Read the first 'size' bytes in packet and advance cursor past them.""" - result = self._data[self._position:(self._position+size)] + result = self._data[self._position : (self._position + size)] if len(result) != size: - error = ('Result length not requested length:\n' - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' - % (size, len(result), self._position, len(self._data))) + error = ( + "Result length not requested length:\n" + "Expected=%s. Actual=%s. Position: %s. Data Length: %s" + % (size, len(result), self._position, len(self._data)) + ) if DEBUG: print(error) self.dump() @@ -77,7 +83,7 @@ def read_all(self): (Subsequent read() will return errors.) """ - result = self._data[self._position:] + result = self._data[self._position :] self._position = None # ensure no subsequent read() return result @@ -85,8 +91,10 @@ def advance(self, length): """Advance the cursor in data buffer 'length' bytes.""" new_position = self._position + length if new_position < 0 or new_position > len(self._data): - raise Exception('Invalid advance amount (%s) for cursor. ' - 'Position=%s' % (length, new_position)) + raise Exception( + "Invalid advance amount (%s) for cursor. " + "Position=%s" % (length, new_position) + ) self._position = new_position def rewind(self, position=0): @@ -104,7 +112,7 @@ def get_bytes(self, position, length=1): No error checking is done. If requesting outside end of buffer an empty string (or string shorter than 'length') may be returned! """ - return self._data[position:(position+length)] + return self._data[position : (position + length)] def read_uint8(self): result = self._data[self._position] @@ -112,30 +120,30 @@ def read_uint8(self): return result def read_uint16(self): - result = struct.unpack_from('= 7 + return self._data[0:1] == b"\0" and len(self._data) >= 7 def is_eof_packet(self): # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet # Caution: \xFE may be LengthEncodedInteger. # If \xFE is LengthEncodedInteger header, 8bytes followed. - return self._data[0:1] == b'\xfe' and len(self._data) < 9 + return self._data[0:1] == b"\xfe" and len(self._data) < 9 def is_auth_switch_request(self): # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - return self._data[0:1] == b'\xfe' + return self._data[0:1] == b"\xfe" def is_extra_auth_data(self): # https://dev.mysql.com/doc/internals/en/successful-authentication.html - return self._data[0:1] == b'\x01' + return self._data[0:1] == b"\x01" def is_resultset_packet(self): field_count = ord(self._data[0:1]) return 1 <= field_count <= 250 def is_load_local_packet(self): - return self._data[0:1] == b'\xfb' + return self._data[0:1] == b"\xfb" def is_error_packet(self): - return self._data[0:1] == b'\xff' + return self._data[0:1] == b"\xff" def check_error(self): if self.is_error_packet(): @@ -211,7 +219,8 @@ def raise_for_error(self): self.rewind() self.advance(1) # field_count == error (we already know that) errno = self.read_uint16() - if DEBUG: print("errno =", errno) + if DEBUG: + print("errno =", errno) err.raise_mysql_exception(self._data) def dump(self): @@ -240,8 +249,13 @@ def _parse_field_descriptor(self, encoding): self.org_table = self.read_length_coded_string().decode(encoding) self.name = self.read_length_coded_string().decode(encoding) self.org_name = self.read_length_coded_string().decode(encoding) - self.charsetnr, self.length, self.type_code, self.flags, self.scale = ( - self.read_struct('= version_tuple @@ -53,10 +59,12 @@ def connect(self, **params): p = self.databases[0].copy() p.update(params) conn = pymysql.connect(**p) + @self.addCleanup def teardown(): if conn.open: conn.close() + return conn def _teardown_connections(self): diff --git a/pymysql/tests/test_DictCursor.py b/pymysql/tests/test_DictCursor.py index 122882e6..581a0c4a 100644 --- a/pymysql/tests/test_DictCursor.py +++ b/pymysql/tests/test_DictCursor.py @@ -6,9 +6,9 @@ class TestDictCursor(base.PyMySQLTestCase): - bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)} - jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)} - fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)} + bob = {"name": "bob", "age": 21, "DOB": datetime.datetime(1990, 2, 6, 23, 4, 56)} + jim = {"name": "jim", "age": 56, "DOB": datetime.datetime(1955, 5, 9, 13, 12, 45)} + fred = {"name": "fred", "age": 100, "DOB": datetime.datetime(1911, 9, 12, 1, 1, 1)} cursor_type = pymysql.cursors.DictCursor @@ -23,10 +23,14 @@ def setUp(self): c.execute("drop table if exists dictcursor") # include in filterwarnings since for unbuffered dict cursor warning for lack of table # will only be propagated at start of next execute() call - c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""") - data = [("bob", 21, "1990-02-06 23:04:56"), - ("jim", 56, "1955-05-09 13:12:45"), - ("fred", 100, "1911-09-12 01:01:01")] + c.execute( + """CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""" + ) + data = [ + ("bob", 21, "1990-02-06 23:04:56"), + ("jim", 56, "1955-05-09 13:12:45"), + ("fred", 100, "1911-09-12 01:01:01"), + ] c.executemany("insert into dictcursor values (%s,%s,%s)", data) def tearDown(self): @@ -39,13 +43,13 @@ def _ensure_cursor_expired(self, cursor): def test_DictCursor(self): bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy() - #all assert test compare to the structure as would come out from MySQLdb + # all assert test compare to the structure as would come out from MySQLdb conn = self.conn c = conn.cursor(self.cursor_type) # try an update which should return no rows c.execute("update dictcursor set age=20 where name='bob'") - bob['age'] = 20 + bob["age"] = 20 # pull back the single row dict for bob and check c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchone() @@ -55,19 +59,23 @@ def test_DictCursor(self): # same again, but via fetchall => tuple) c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchall() - self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor") + self.assertEqual( + [bob], r, "fetch a 1 row result via fetchall failed via DictCursor" + ) # same test again but iterate over the c.execute("SELECT * from dictcursor where name='bob'") for r in c: - self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor") + self.assertEqual( + bob, r, "fetch a 1 row result via iteration failed via DictCursor" + ) # get all 3 row via fetchall c.execute("SELECT * from dictcursor") r = c.fetchall() - self.assertEqual([bob,jim,fred], r, "fetchall failed via DictCursor") - #same test again but do a list comprehension + self.assertEqual([bob, jim, fred], r, "fetchall failed via DictCursor") + # same test again but do a list comprehension c.execute("SELECT * from dictcursor") r = list(c) - self.assertEqual([bob,jim,fred], r, "DictCursor should be iterable") + self.assertEqual([bob, jim, fred], r, "DictCursor should be iterable") # get all 2 row via fetchmany c.execute("SELECT * from dictcursor") r = c.fetchmany(2) @@ -75,12 +83,13 @@ def test_DictCursor(self): self._ensure_cursor_expired(c) def test_custom_dict(self): - class MyDict(dict): pass + class MyDict(dict): + pass class MyDictCursor(self.cursor_type): dict_type = MyDict - keys = ['name', 'age', 'DOB'] + keys = ["name", "age", "DOB"] bob = MyDict([(k, self.bob[k]) for k in keys]) jim = MyDict([(k, self.jim[k]) for k in keys]) fred = MyDict([(k, self.fred[k]) for k in keys]) @@ -93,18 +102,15 @@ class MyDictCursor(self.cursor_type): cur.execute("SELECT * FROM dictcursor") r = cur.fetchall() - self.assertEqual([bob, jim, fred], r, - "fetchall failed via MyDictCursor") + self.assertEqual([bob, jim, fred], r, "fetchall failed via MyDictCursor") cur.execute("SELECT * FROM dictcursor") r = list(cur) - self.assertEqual([bob, jim, fred], r, - "list failed via MyDictCursor") + self.assertEqual([bob, jim, fred], r, "list failed via MyDictCursor") cur.execute("SELECT * FROM dictcursor") r = cur.fetchmany(2) - self.assertEqual([bob, jim], r, - "list failed via MyDictCursor") + self.assertEqual([bob, jim], r, "list failed via MyDictCursor") self._ensure_cursor_expired(cur) @@ -114,6 +120,8 @@ class TestSSDictCursor(TestDictCursor): def _ensure_cursor_expired(self, cursor): list(cursor.fetchall_unbuffered()) + if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py index 2b0de78a..a68a7769 100644 --- a/pymysql/tests/test_SSCursor.py +++ b/pymysql/tests/test_SSCursor.py @@ -6,7 +6,7 @@ from pymysql.constants import CLIENT except Exception: # For local testing from top-level directory, without installing - sys.path.append('../pymysql') + sys.path.append("../pymysql") from pymysql.tests import base import pymysql.cursors from pymysql.constants import CLIENT @@ -18,35 +18,38 @@ def test_SSCursor(self): conn = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) data = [ - ('America', '', 'America/Jamaica'), - ('America', '', 'America/Los_Angeles'), - ('America', '', 'America/Lima'), - ('America', '', 'America/New_York'), - ('America', '', 'America/Menominee'), - ('America', '', 'America/Havana'), - ('America', '', 'America/El_Salvador'), - ('America', '', 'America/Costa_Rica'), - ('America', '', 'America/Denver'), - ('America', '', 'America/Detroit'),] + ("America", "", "America/Jamaica"), + ("America", "", "America/Los_Angeles"), + ("America", "", "America/Lima"), + ("America", "", "America/New_York"), + ("America", "", "America/Menominee"), + ("America", "", "America/Havana"), + ("America", "", "America/El_Salvador"), + ("America", "", "America/Costa_Rica"), + ("America", "", "America/Denver"), + ("America", "", "America/Detroit"), + ] cursor = conn.cursor(pymysql.cursors.SSCursor) # Create table - cursor.execute('CREATE TABLE tz_data (' - 'region VARCHAR(64),' - 'zone VARCHAR(64),' - 'name VARCHAR(64))') + cursor.execute( + "CREATE TABLE tz_data (" + "region VARCHAR(64)," + "zone VARCHAR(64)," + "name VARCHAR(64))" + ) conn.begin() # Test INSERT for i in data: - cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i) - self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match') + cursor.execute("INSERT INTO tz_data VALUES (%s, %s, %s)", i) + self.assertEqual(conn.affected_rows(), 1, "affected_rows does not match") conn.commit() # Test fetchone() iter = 0 - cursor.execute('SELECT * FROM tz_data') + cursor.execute("SELECT * FROM tz_data") while True: row = cursor.fetchone() if row is None: @@ -54,26 +57,35 @@ def test_SSCursor(self): iter += 1 # Test cursor.rowcount - self.assertEqual(cursor.rowcount, affected_rows, - 'cursor.rowcount != %s' % (str(affected_rows))) + self.assertEqual( + cursor.rowcount, + affected_rows, + "cursor.rowcount != %s" % (str(affected_rows)), + ) # Test cursor.rownumber - self.assertEqual(cursor.rownumber, iter, - 'cursor.rowcount != %s' % (str(iter))) + self.assertEqual( + cursor.rownumber, iter, "cursor.rowcount != %s" % (str(iter)) + ) # Test row came out the same as it went in - self.assertEqual((row in data), True, - 'Row not found in source data') + self.assertEqual((row in data), True, "Row not found in source data") # Test fetchall - cursor.execute('SELECT * FROM tz_data') - self.assertEqual(len(cursor.fetchall()), len(data), - 'fetchall failed. Number of rows does not match') + cursor.execute("SELECT * FROM tz_data") + self.assertEqual( + len(cursor.fetchall()), + len(data), + "fetchall failed. Number of rows does not match", + ) # Test fetchmany - cursor.execute('SELECT * FROM tz_data') - self.assertEqual(len(cursor.fetchmany(2)), 2, - 'fetchmany failed. Number of rows does not match') + cursor.execute("SELECT * FROM tz_data") + self.assertEqual( + len(cursor.fetchmany(2)), + 2, + "fetchmany failed. Number of rows does not match", + ) # So MySQLdb won't throw "Commands out of sync" while True: @@ -82,30 +94,38 @@ def test_SSCursor(self): break # Test update, affected_rows() - cursor.execute('UPDATE tz_data SET zone = %s', ['Foo']) + cursor.execute("UPDATE tz_data SET zone = %s", ["Foo"]) conn.commit() - self.assertEqual(cursor.rowcount, len(data), - 'Update failed. affected_rows != %s' % (str(len(data)))) + self.assertEqual( + cursor.rowcount, + len(data), + "Update failed. affected_rows != %s" % (str(len(data))), + ) # Test executemany - cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data) - self.assertEqual(cursor.rowcount, len(data), - 'executemany failed. cursor.rowcount != %s' % (str(len(data)))) + cursor.executemany("INSERT INTO tz_data VALUES (%s, %s, %s)", data) + self.assertEqual( + cursor.rowcount, + len(data), + "executemany failed. cursor.rowcount != %s" % (str(len(data))), + ) # Test multiple datasets - cursor.execute('SELECT 1; SELECT 2; SELECT 3') - self.assertListEqual(list(cursor), [(1, )]) + cursor.execute("SELECT 1; SELECT 2; SELECT 3") + self.assertListEqual(list(cursor), [(1,)]) self.assertTrue(cursor.nextset()) - self.assertListEqual(list(cursor), [(2, )]) + self.assertListEqual(list(cursor), [(2,)]) self.assertTrue(cursor.nextset()) - self.assertListEqual(list(cursor), [(3, )]) + self.assertListEqual(list(cursor), [(3,)]) self.assertFalse(cursor.nextset()) - cursor.execute('DROP TABLE IF EXISTS tz_data') + cursor.execute("DROP TABLE IF EXISTS tz_data") cursor.close() + __all__ = ["TestSSCursor"] if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_basic.py b/pymysql/tests/test_basic.py index 840c4860..f8e622e6 100644 --- a/pymysql/tests/test_basic.py +++ b/pymysql/tests/test_basic.py @@ -18,23 +18,46 @@ def test_datatypes(self): """ test every data type """ conn = self.connect() c = conn.cursor() - c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)") + c.execute( + "create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)" + ) try: # insert values - v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.encoding), datetime.date(1988,2,2), datetime.datetime(2014, 5, 15, 7, 45, 57), datetime.timedelta(5,6), datetime.time(16,32), time.localtime()) - c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) + v = ( + True, + -3, + 123456789012, + 5.7, + "hello'\" world", + u"Espa\xc3\xb1ol", + "binary\x00data".encode(conn.encoding), + datetime.date(1988, 2, 2), + datetime.datetime(2014, 5, 15, 7, 45, 57), + datetime.timedelta(5, 6), + datetime.time(16, 32), + time.localtime(), + ) + c.execute( + "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + v, + ) c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") r = c.fetchone() self.assertEqual(util.int2byte(1), r[0]) self.assertEqual(v[1:10], r[1:10]) - self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10]) + self.assertEqual( + datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10] + ) self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1]) c.execute("delete from test_datatypes") # check nulls - c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12) + c.execute( + "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + [None] * 12, + ) c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") r = c.fetchone() self.assertEqual(tuple([None] * 12), r) @@ -43,11 +66,15 @@ def test_datatypes(self): # check sequences type for seq_type in (tuple, list, set, frozenset): - c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)") - seq = seq_type([2,6]) - c.execute("select l from test_datatypes where i in %s order by i", (seq,)) + c.execute( + "insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)" + ) + seq = seq_type([2, 6]) + c.execute( + "select l from test_datatypes where i in %s order by i", (seq,) + ) r = c.fetchall() - self.assertEqual(((4,),(8,)), r) + self.assertEqual(((4,), (8,)), r) c.execute("delete from test_datatypes") finally: @@ -59,9 +86,12 @@ def test_dict(self): c = conn.cursor() c.execute("create table test_dict (a integer, b integer, c integer)") try: - c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3}) + c.execute( + "insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", + {"a": 1, "b": 2, "c": 3}, + ) c.execute("select a,b,c from test_dict") - self.assertEqual((1,2,3), c.fetchone()) + self.assertEqual((1, 2, 3), c.fetchone()) finally: c.execute("drop table test_dict") @@ -94,7 +124,8 @@ def test_binary(self): data = bytes(bytearray(range(255))) conn = self.connect() self.safe_create_table( - conn, "test_binary", "create table test_binary (b binary(255))") + conn, "test_binary", "create table test_binary (b binary(255))" + ) with conn.cursor() as c: c.execute("insert into test_binary (b) values (_binary %s)", (data,)) @@ -105,8 +136,7 @@ def test_blob(self): """test blob data""" data = bytes(bytearray(range(256)) * 4) conn = self.connect() - self.safe_create_table( - conn, "test_blob", "create table test_blob (b blob)") + self.safe_create_table(conn, "test_blob", "create table test_blob (b blob)") with conn.cursor() as c: c.execute("insert into test_blob (b) values (_binary %s)", (data,)) @@ -118,23 +148,29 @@ def test_untyped(self): conn = self.connect() c = conn.cursor() c.execute("select null,''") - self.assertEqual((None,u''), c.fetchone()) + self.assertEqual((None, u""), c.fetchone()) c.execute("select '',null") - self.assertEqual((u'',None), c.fetchone()) + self.assertEqual((u"", None), c.fetchone()) def test_timedelta(self): """ test timedelta conversion """ conn = self.connect() c = conn.cursor() - c.execute("select time('12:30'), time('23:12:59'), time('23:12:59.05100'), time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')") - self.assertEqual((datetime.timedelta(0, 45000), - datetime.timedelta(0, 83579), - datetime.timedelta(0, 83579, 51000), - -datetime.timedelta(0, 45000), - -datetime.timedelta(0, 83579), - -datetime.timedelta(0, 83579, 51000), - -datetime.timedelta(0, 1800)), - c.fetchone()) + c.execute( + "select time('12:30'), time('23:12:59'), time('23:12:59.05100'), time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')" + ) + self.assertEqual( + ( + datetime.timedelta(0, 45000), + datetime.timedelta(0, 83579), + datetime.timedelta(0, 83579, 51000), + -datetime.timedelta(0, 45000), + -datetime.timedelta(0, 83579), + -datetime.timedelta(0, 83579, 51000), + -datetime.timedelta(0, 1800), + ), + c.fetchone(), + ) def test_datetime_microseconds(self): """ test datetime conversion w microseconds""" @@ -146,10 +182,7 @@ def test_datetime_microseconds(self): dt = datetime.datetime(2013, 11, 12, 9, 9, 9, 123450) c.execute("create table test_datetime (id int, ts datetime(6))") try: - c.execute( - "insert into test_datetime values (%s, %s)", - (1, dt) - ) + c.execute("insert into test_datetime values (%s, %s)", (1, dt)) c.execute("select ts from test_datetime") self.assertEqual((dt,), c.fetchone()) finally: @@ -162,7 +195,7 @@ class TestCursor(base.PyMySQLTestCase): # compatible with the DB-API 2.0 spec and has not broken # any unit tests for anything we've tried. - #def test_description(self): + # def test_description(self): # """ test description attribute """ # # result is from MySQLdb module # r = (('Host', 254, 11, 60, 60, 0, 0), @@ -227,22 +260,22 @@ def test_aggregates(self): conn = self.connect() c = conn.cursor() try: - c.execute('create table test_aggregates (i integer)') + c.execute("create table test_aggregates (i integer)") for i in range(0, 10): - c.execute('insert into test_aggregates (i) values (%s)', (i,)) - c.execute('select sum(i) from test_aggregates') - r, = c.fetchone() - self.assertEqual(sum(range(0,10)), r) + c.execute("insert into test_aggregates (i) values (%s)", (i,)) + c.execute("select sum(i) from test_aggregates") + (r,) = c.fetchone() + self.assertEqual(sum(range(0, 10)), r) finally: - c.execute('drop table test_aggregates') + c.execute("drop table test_aggregates") def test_single_tuple(self): """ test a single tuple """ conn = self.connect() c = conn.cursor() self.safe_create_table( - conn, 'mystuff', - "create table mystuff (id integer primary key)") + conn, "mystuff", "create table mystuff (id integer primary key)" + ) c.execute("insert into mystuff (id) values (1)") c.execute("insert into mystuff (id) values (2)") c.execute("select id from mystuff where id in %s", ((1,),)) @@ -256,12 +289,16 @@ def test_json(self): if not self.mysql_server_is(conn, (5, 7, 0)): pytest.skip("JSON type is not supported on MySQL <= 5.6") - self.safe_create_table(conn, "test_json", """\ + self.safe_create_table( + conn, + "test_json", + """\ create table test_json ( id int not null, json JSON not null, primary key (id) -);""") +);""", + ) cur = conn.cursor() json_str = u'{"hello": "こんにちは"}' @@ -285,7 +322,10 @@ def setUp(self): c = conn.cursor(self.cursor_type) # create a table ane some data to query - self.safe_create_table(conn, 'bulkinsert', """\ + self.safe_create_table( + conn, + "bulkinsert", + """\ CREATE TABLE bulkinsert ( id int, @@ -294,7 +334,8 @@ def setUp(self): height int, PRIMARY KEY (id) ) -""") +""", + ) def _verify_records(self, data): conn = self.connect() @@ -308,27 +349,38 @@ def test_bulk_insert(self): cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("insert into bulkinsert (id, name, age, height) " - "values (%s,%s,%s,%s)", data) + cursor.executemany( + "insert into bulkinsert (id, name, age, height) " "values (%s,%s,%s,%s)", + data, + ) self.assertEqual( - cursor._last_executed, bytearray( - b"insert into bulkinsert (id, name, age, height) values " - b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)")) - cursor.execute('commit') + cursor._last_executed, + bytearray( + b"insert into bulkinsert (id, name, age, height) values " + b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)" + ), + ) + cursor.execute("commit") self._verify_records(data) def test_bulk_insert_multiline_statement(self): conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("""insert + cursor.executemany( + """insert into bulkinsert (id, name, age, height) values (%s, %s , %s, %s ) - """, data) - self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert + """, + data, + ) + self.assertEqual( + cursor._last_executed.strip(), + bytearray( + b"""insert into bulkinsert (id, name, age, height) values (0, @@ -337,17 +389,21 @@ def test_bulk_insert_multiline_statement(self): 'jim' , 56, 45 ),(2, 'fred' , 100, -180 )""")) - cursor.execute('commit') +180 )""" + ), + ) + cursor.execute("commit") self._verify_records(data) def test_bulk_insert_single_record(self): conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123)] - cursor.executemany("insert into bulkinsert (id, name, age, height) " - "values (%s,%s,%s,%s)", data) - cursor.execute('commit') + cursor.executemany( + "insert into bulkinsert (id, name, age, height) " "values (%s,%s,%s,%s)", + data, + ) + cursor.execute("commit") self._verify_records(data) def test_issue_288(self): @@ -355,15 +411,21 @@ def test_issue_288(self): conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("""insert + cursor.executemany( + """insert into bulkinsert (id, name, age, height) values (%s, %s , %s, %s ) on duplicate key update age = values(age) - """, data) - self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert + """, + data, + ) + self.assertEqual( + cursor._last_executed.strip(), + bytearray( + b"""insert into bulkinsert (id, name, age, height) values (0, @@ -373,6 +435,8 @@ def test_issue_288(self): 45 ),(2, 'fred' , 100, 180 ) on duplicate key update -age = values(age)""")) - cursor.execute('commit') +age = values(age)""" + ), + ) + cursor.execute("commit") self._verify_records(data) diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index db36c3e6..abd30e0b 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -54,34 +54,37 @@ class TestAuthentication(base.PyMySQLTestCase): sha256_password_found = False import os - osuser = os.environ.get('USER') + + osuser = os.environ.get("USER") # socket auth requires the current user and for the connection to be a socket # rest do grants @localhost due to incomplete logic - TODO change to @% then db = base.PyMySQLTestCase.databases[0].copy() - socket_auth = db.get('unix_socket') is not None \ - and db.get('host') in ('localhost', '127.0.0.1') + socket_auth = db.get("unix_socket") is not None and db.get("host") in ( + "localhost", + "127.0.0.1", + ) cur = pymysql.connect(**db).cursor() - del db['user'] + del db["user"] cur.execute("SHOW PLUGINS") for r in cur: - if (r[1], r[2]) != (u'ACTIVE', u'AUTHENTICATION'): + if (r[1], r[2]) != (u"ACTIVE", u"AUTHENTICATION"): continue - if r[3] == u'auth_socket.so' or r[0] == u'unix_socket': + if r[3] == u"auth_socket.so" or r[0] == u"unix_socket": socket_plugin_name = r[0] socket_found = True - elif r[3] == u'dialog_examples.so': - if r[0] == 'two_questions': - two_questions_found = True - elif r[0] == 'three_attempts': - three_attempts_found = True - elif r[0] == u'pam': + elif r[3] == u"dialog_examples.so": + if r[0] == "two_questions": + two_questions_found = True + elif r[0] == "three_attempts": + three_attempts_found = True + elif r[0] == u"pam": pam_found = True - pam_plugin_name = r[3].split('.')[0] - if pam_plugin_name == 'auth_pam': - pam_plugin_name = 'pam' + pam_plugin_name = r[3].split(".")[0] + if pam_plugin_name == "auth_pam": + pam_plugin_name = "pam" # MySQL: authentication_pam # https://dev.mysql.com/doc/refman/5.5/en/pam-authentication-plugin.html @@ -89,11 +92,11 @@ class TestAuthentication(base.PyMySQLTestCase): # https://mariadb.com/kb/en/mariadb/pam-authentication-plugin/ # Names differ but functionality is close - elif r[0] == u'mysql_old_password': + elif r[0] == u"mysql_old_password": mysql_old_password_found = True - elif r[0] == u'sha256_password': + elif r[0] == u"sha256_password": sha256_password_found = True - #else: + # else: # print("plugin: %r" % r[0]) def test_plugin(self): @@ -101,9 +104,11 @@ def test_plugin(self): if not self.mysql_server_is(conn, (5, 5, 0)): pytest.skip("MySQL-5.5 required for plugins") cur = conn.cursor() - cur.execute("select plugin from mysql.user where concat(user, '@', host)=current_user()") + cur.execute( + "select plugin from mysql.user where concat(user, '@', host)=current_user()" + ) for r in cur: - self.assertIn(conn._auth_plugin_name, (r[0], 'mysql_native_password')) + self.assertIn(conn._auth_plugin_name, (r[0], "mysql_native_password")) @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") @pytest.mark.skipif(socket_found, reason="socket plugin already installed") @@ -113,17 +118,17 @@ def testSocketAuthInstallPlugin(self): try: cur.execute("install plugin auth_socket soname 'auth_socket.so'") TestAuthentication.socket_found = True - self.socket_plugin_name = 'auth_socket' + self.socket_plugin_name = "auth_socket" self.realtestSocketAuth() except pymysql.err.InternalError: try: cur.execute("install soname 'auth_socket'") TestAuthentication.socket_found = True - self.socket_plugin_name = 'unix_socket' + self.socket_plugin_name = "unix_socket" self.realtestSocketAuth() except pymysql.err.InternalError: TestAuthentication.socket_found = False - pytest.skip('we couldn\'t install the socket plugin') + pytest.skip("we couldn't install the socket plugin") finally: if TestAuthentication.socket_found: cur.execute("uninstall plugin %s" % self.socket_plugin_name) @@ -134,27 +139,30 @@ def testSocketAuth(self): self.realtestSocketAuth() def realtestSocketAuth(self): - with TempUser(self.connect().cursor(), TestAuthentication.osuser + '@localhost', - self.databases[0]['db'], self.socket_plugin_name) as u: + with TempUser( + self.connect().cursor(), + TestAuthentication.osuser + "@localhost", + self.databases[0]["db"], + self.socket_plugin_name, + ) as u: c = pymysql.connect(user=TestAuthentication.osuser, **self.db) class Dialog: - fail=False + fail = False def __init__(self, con): - self.fail=TestAuthentication.Dialog.fail + self.fail = TestAuthentication.Dialog.fail pass def prompt(self, echo, prompt): if self.fail: - self.fail=False - return b'bad guess at a password' + self.fail = False + return b"bad guess at a password" return self.m.get(prompt) class DialogHandler: - def __init__(self, con): - self.con=con + self.con = con def authenticate(self, pkt): while True: @@ -163,10 +171,10 @@ def authenticate(self, pkt): last = (flag & 0x01) == 0x01 prompt = pkt.read_all() - if prompt == b'Password, please:': - self.con.write_packet(b'stillnotverysecret\0') + if prompt == b"Password, please:": + self.con.write_packet(b"stillnotverysecret\0") else: - self.con.write_packet(b'no idea what to do with this prompt\0') + self.con.write_packet(b"no idea what to do with this prompt\0") pkt = self.con._read_packet() pkt.check_error() if pkt.is_ok_packet() or last: @@ -175,11 +183,12 @@ def authenticate(self, pkt): class DefectiveHandler: def __init__(self, con): - self.con=con - + self.con = con @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") - @pytest.mark.skipif(two_questions_found, reason="two_questions plugin already installed") + @pytest.mark.skipif( + two_questions_found, reason="two_questions plugin already installed" + ) def testDialogAuthTwoQuestionsInstallPlugin(self): # needs plugin. lets install it. cur = self.connect().cursor() @@ -188,7 +197,7 @@ def testDialogAuthTwoQuestionsInstallPlugin(self): TestAuthentication.two_questions_found = True self.realTestDialogAuthTwoQuestions() except pymysql.err.InternalError: - pytest.skip('we couldn\'t install the two_questions plugin') + pytest.skip("we couldn't install the two_questions plugin") finally: if TestAuthentication.two_questions_found: cur.execute("uninstall plugin two_questions") @@ -199,17 +208,30 @@ def testDialogAuthTwoQuestions(self): self.realTestDialogAuthTwoQuestions() def realTestDialogAuthTwoQuestions(self): - TestAuthentication.Dialog.fail=False - TestAuthentication.Dialog.m = {b'Password, please:': b'notverysecret', - b'Are you sure ?': b'yes, of course'} - with TempUser(self.connect().cursor(), 'pymysql_2q@localhost', - self.databases[0]['db'], 'two_questions', 'notverysecret') as u: + TestAuthentication.Dialog.fail = False + TestAuthentication.Dialog.m = { + b"Password, please:": b"notverysecret", + b"Are you sure ?": b"yes, of course", + } + with TempUser( + self.connect().cursor(), + "pymysql_2q@localhost", + self.databases[0]["db"], + "two_questions", + "notverysecret", + ) as u: with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_2q', **self.db) - pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) + pymysql.connect(user="pymysql_2q", **self.db) + pymysql.connect( + user="pymysql_2q", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db + ) @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") - @pytest.mark.skipif(three_attempts_found, reason="three_attempts plugin already installed") + @pytest.mark.skipif( + three_attempts_found, reason="three_attempts plugin already installed" + ) def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self): # needs plugin. lets install it. cur = self.connect().cursor() @@ -218,7 +240,7 @@ def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self): TestAuthentication.three_attempts_found = True self.realTestDialogAuthThreeAttempts() except pymysql.err.InternalError: - pytest.skip('we couldn\'t install the three_attempts plugin') + pytest.skip("we couldn't install the three_attempts plugin") finally: if TestAuthentication.three_attempts_found: cur.execute("uninstall plugin three_attempts") @@ -229,30 +251,67 @@ def testDialogAuthThreeAttempts(self): self.realTestDialogAuthThreeAttempts() def realTestDialogAuthThreeAttempts(self): - TestAuthentication.Dialog.m = {b'Password, please:': b'stillnotverysecret'} - TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all - with TempUser(self.connect().cursor(), 'pymysql_3a@localhost', - self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u: - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db) + TestAuthentication.Dialog.m = {b"Password, please:": b"stillnotverysecret"} + TestAuthentication.Dialog.fail = ( + True # fail just once. We've got three attempts after all + ) + with TempUser( + self.connect().cursor(), + "pymysql_3a@localhost", + self.databases[0]["db"], + "three_attempts", + "stillnotverysecret", + ) as u: + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db + ) + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.DialogHandler}, + **self.db + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db) + pymysql.connect( + user="pymysql_3a", auth_plugin_map={b"dialog": object}, **self.db + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db) + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.DefectiveHandler}, + **self.db + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db) - TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'} + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"notdialogplugin": TestAuthentication.Dialog}, + **self.db + ) + TestAuthentication.Dialog.m = {b"Password, please:": b"I do not know"} with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - TestAuthentication.Dialog.m = {b'Password, please:': None} + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db + ) + TestAuthentication.Dialog.m = {b"Password, please:": None} with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db + ) @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") @pytest.mark.skipif(pam_found, reason="pam plugin already installed") - @pytest.mark.skipif(os.environ.get('PASSWORD') is None, reason="PASSWORD env var required") - @pytest.mark.skipif(os.environ.get('PAMSERVICE') is None, reason="PAMSERVICE env var required") + @pytest.mark.skipif( + os.environ.get("PASSWORD") is None, reason="PASSWORD env var required" + ) + @pytest.mark.skipif( + os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required" + ) def testPamAuthInstallPlugin(self): # needs plugin. lets install it. cur = self.connect().cursor() @@ -261,133 +320,162 @@ def testPamAuthInstallPlugin(self): TestAuthentication.pam_found = True self.realTestPamAuth() except pymysql.err.InternalError: - pytest.skip('we couldn\'t install the auth_pam plugin') + pytest.skip("we couldn't install the auth_pam plugin") finally: if TestAuthentication.pam_found: cur.execute("uninstall plugin pam") - @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") @pytest.mark.skipif(not pam_found, reason="no pam plugin") - @pytest.mark.skipif(os.environ.get('PASSWORD') is None, reason="PASSWORD env var required") - @pytest.mark.skipif(os.environ.get('PAMSERVICE') is None, reason="PAMSERVICE env var required") + @pytest.mark.skipif( + os.environ.get("PASSWORD") is None, reason="PASSWORD env var required" + ) + @pytest.mark.skipif( + os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required" + ) def testPamAuth(self): self.realTestPamAuth() def realTestPamAuth(self): db = self.db.copy() import os - db['password'] = os.environ.get('PASSWORD') + + db["password"] = os.environ.get("PASSWORD") cur = self.connect().cursor() try: - cur.execute('show grants for ' + TestAuthentication.osuser + '@localhost') + cur.execute("show grants for " + TestAuthentication.osuser + "@localhost") grants = cur.fetchone()[0] - cur.execute('drop user ' + TestAuthentication.osuser + '@localhost') + cur.execute("drop user " + TestAuthentication.osuser + "@localhost") except pymysql.OperationalError as e: # assuming the user doesn't exist which is ok too self.assertEqual(1045, e.args[0]) grants = None - with TempUser(cur, TestAuthentication.osuser + '@localhost', - self.databases[0]['db'], 'pam', os.environ.get('PAMSERVICE')) as u: + with TempUser( + cur, + TestAuthentication.osuser + "@localhost", + self.databases[0]["db"], + "pam", + os.environ.get("PAMSERVICE"), + ) as u: try: c = pymysql.connect(user=TestAuthentication.osuser, **db) - db['password'] = 'very bad guess at password' + db["password"] = "very bad guess at password" with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user=TestAuthentication.osuser, - auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, - **self.db) + pymysql.connect( + user=TestAuthentication.osuser, + auth_plugin_map={ + b"mysql_cleartext_password": TestAuthentication.DefectiveHandler + }, + **self.db + ) except pymysql.OperationalError as e: self.assertEqual(1045, e.args[0]) # we had 'bad guess at password' work with pam. Well at least we get a permission denied here with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user=TestAuthentication.osuser, - auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, - **self.db) + pymysql.connect( + user=TestAuthentication.osuser, + auth_plugin_map={ + b"mysql_cleartext_password": TestAuthentication.DefectiveHandler + }, + **self.db + ) if grants: # recreate the user cur.execute(grants) # select old_password("crummy p\tassword"); - #| old_password("crummy p\tassword") | - #| 2a01785203b08770 | + # | old_password("crummy p\tassword") | + # | 2a01785203b08770 | @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") - @pytest.mark.skipif(not mysql_old_password_found, reason="no mysql_old_password plugin") + @pytest.mark.skipif( + not mysql_old_password_found, reason="no mysql_old_password plugin" + ) def testMySQLOldPasswordAuth(self): conn = self.connect() if self.mysql_server_is(conn, (5, 7, 0)): - pytest.skip('Old passwords aren\'t supported in 5.7') + pytest.skip("Old passwords aren't supported in 5.7") # pymysql.err.OperationalError: (1045, "Access denied for user 'old_pass_user'@'localhost' (using password: YES)") # from login in MySQL-5.6 if self.mysql_server_is(conn, (5, 6, 0)): - pytest.skip('Old passwords don\'t authenticate in 5.6') + pytest.skip("Old passwords don't authenticate in 5.6") db = self.db.copy() - db['password'] = "crummy p\tassword" + db["password"] = "crummy p\tassword" c = conn.cursor() # deprecated in 5.6 - if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(conn, (5, 6, 0)): + if sys.version_info[0:2] >= (3, 2) and self.mysql_server_is(conn, (5, 6, 0)): with self.assertWarns(pymysql.err.Warning) as cm: - c.execute("SELECT OLD_PASSWORD('%s')" % db['password']) + c.execute("SELECT OLD_PASSWORD('%s')" % db["password"]) else: - c.execute("SELECT OLD_PASSWORD('%s')" % db['password']) + c.execute("SELECT OLD_PASSWORD('%s')" % db["password"]) v = c.fetchone()[0] - self.assertEqual(v, '2a01785203b08770') + self.assertEqual(v, "2a01785203b08770") # only works in MariaDB and MySQL-5.6 - can't separate out by version - #if self.mysql_server_is(self.connect(), (5, 5, 0)): + # if self.mysql_server_is(self.connect(), (5, 5, 0)): # with TempUser(c, 'old_pass_user@localhost', # self.databases[0]['db'], 'mysql_old_password', '2a01785203b08770') as u: # cur = pymysql.connect(user='old_pass_user', **db).cursor() # cur.execute("SELECT VERSION()") c.execute("SELECT @@secure_auth") secure_auth_setting = c.fetchone()[0] - c.execute('set old_passwords=1') + c.execute("set old_passwords=1") # pymysql.err.Warning: 'pre-4.1 password hash' is deprecated and will be removed in a future release. Please use post-4.1 password hash instead - if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(conn, (5, 6, 0)): + if sys.version_info[0:2] >= (3, 2) and self.mysql_server_is(conn, (5, 6, 0)): with self.assertWarns(pymysql.err.Warning) as cm: - c.execute('set global secure_auth=0') + c.execute("set global secure_auth=0") else: - c.execute('set global secure_auth=0') - with TempUser(c, 'old_pass_user@localhost', - self.databases[0]['db'], password=db['password']) as u: - cur = pymysql.connect(user='old_pass_user', **db).cursor() + c.execute("set global secure_auth=0") + with TempUser( + c, + "old_pass_user@localhost", + self.databases[0]["db"], + password=db["password"], + ) as u: + cur = pymysql.connect(user="old_pass_user", **db).cursor() cur.execute("SELECT VERSION()") - c.execute('set global secure_auth=%r' % secure_auth_setting) + c.execute("set global secure_auth=%r" % secure_auth_setting) @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") - @pytest.mark.skipif(not sha256_password_found, reason="no sha256 password authentication plugin found") + @pytest.mark.skipif( + not sha256_password_found, + reason="no sha256 password authentication plugin found", + ) def testAuthSHA256(self): conn = self.connect() c = conn.cursor() - with TempUser(c, 'pymysql_sha256@localhost', - self.databases[0]['db'], 'sha256_password') as u: + with TempUser( + c, "pymysql_sha256@localhost", self.databases[0]["db"], "sha256_password" + ) as u: if self.mysql_server_is(conn, (5, 7, 0)): c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'") else: - c.execute('SET old_passwords = 2') - c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')") + c.execute("SET old_passwords = 2") + c.execute( + "SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')" + ) c.execute("FLUSH PRIVILEGES") db = self.db.copy() - db['password'] = "Sh@256Pa33" - # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. + db["password"] = "Sh@256Pa33" + # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_sha256', **db) + pymysql.connect(user="pymysql_sha256", **db) -class TestConnection(base.PyMySQLTestCase): +class TestConnection(base.PyMySQLTestCase): def test_utf8mb4(self): """This test requires MySQL >= 5.5""" arg = self.databases[0].copy() - arg['charset'] = 'utf8mb4' + arg["charset"] = "utf8mb4" conn = pymysql.connect(**arg) def test_largedata(self): """Large query and response (>=16MB)""" cur = self.connect().cursor() cur.execute("SELECT @@max_allowed_packet") - if cur.fetchone()[0] < 16*1024*1024 + 10: + if cur.fetchone()[0] < 16 * 1024 * 1024 + 10: print("Set max_allowed_packet to bigger than 17MB") return - t = 'a' * (16*1024*1024) + t = "a" * (16 * 1024 * 1024) cur.execute("SELECT '" + t + "'") assert cur.fetchone()[0] == t @@ -406,15 +494,15 @@ def test_autocommit(self): def test_select_db(self): con = self.connect() - current_db = self.databases[0]['db'] - other_db = self.databases[1]['db'] + current_db = self.databases[0]["db"] + other_db = self.databases[1]["db"] cur = con.cursor() - cur.execute('SELECT database()') + cur.execute("SELECT database()") self.assertEqual(cur.fetchone()[0], current_db) con.select_db(other_db) - cur.execute('SELECT database()') + cur.execute("SELECT database()") self.assertEqual(cur.fetchone()[0], other_db) def test_connection_gone_away(self): @@ -429,29 +517,30 @@ def test_connection_gone_away(self): with self.assertRaises(pymysql.OperationalError) as cm: cur.execute("SELECT 1+1") # error occures while reading, not writing because of socket buffer. - #self.assertEqual(cm.exception.args[0], 2006) + # self.assertEqual(cm.exception.args[0], 2006) self.assertIn(cm.exception.args[0], (2006, 2013)) def test_init_command(self): conn = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) c = conn.cursor() c.execute('select "foobar";') - self.assertEqual(('foobar',), c.fetchone()) + self.assertEqual(("foobar",), c.fetchone()) conn.close() with self.assertRaises(pymysql.err.Error): conn.ping(reconnect=False) def test_read_default_group(self): conn = self.connect( - read_default_group='client', + read_default_group="client", ) self.assertTrue(conn.open) def test_set_charset(self): c = self.connect() - c.set_charset('utf8mb4') + c.set_charset("utf8mb4") # TODO validate setting here def test_defer_connect(self): @@ -460,12 +549,13 @@ def test_defer_connect(self): d = self.databases[0].copy() try: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(d['unix_socket']) + sock.connect(d["unix_socket"]) except KeyError: sock.close() sock = socket.create_connection( - (d.get('host', 'localhost'), d.get('port', 3306))) - for k in ['unix_socket', 'host', 'port']: + (d.get("host", "localhost"), d.get("port", 3306)) + ) + for k in ["unix_socket", "host", "port"]: try: del d[k] except KeyError: @@ -479,9 +569,12 @@ def test_defer_connect(self): def test_ssl_connect(self): dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl={ "ca": "ca", @@ -497,9 +590,12 @@ def test_ssl_connect(self): dummy_ssl_context.set_ciphers.assert_called_with("cipher") dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl={ "ca": "ca", @@ -514,9 +610,12 @@ def test_ssl_connect(self): dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_ca="ca", ) @@ -527,9 +626,12 @@ def test_ssl_connect(self): dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_ca="ca", ssl_cert="cert", @@ -543,9 +645,12 @@ def test_ssl_connect(self): for ssl_verify_cert in (True, "1", "yes", "true"): dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_cert="cert", ssl_key="key", @@ -554,14 +659,19 @@ def test_ssl_connect(self): assert create_default_context.called assert not dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", keyfile="key" + ) dummy_ssl_context.set_ciphers.assert_not_called for ssl_verify_cert in (None, False, "0", "no", "false"): dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_cert="cert", ssl_key="key", @@ -570,15 +680,20 @@ def test_ssl_connect(self): assert create_default_context.called assert not dummy_ssl_context.check_hostname assert dummy_ssl_context.verify_mode == ssl.CERT_NONE - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", keyfile="key" + ) dummy_ssl_context.set_ciphers.assert_not_called for ssl_ca in ("ca", None): for ssl_verify_cert in ("foo", "bar", ""): dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_ca=ssl_ca, ssl_cert="cert", @@ -587,14 +702,21 @@ def test_ssl_connect(self): ) assert create_default_context.called assert not dummy_ssl_context.check_hostname - assert dummy_ssl_context.verify_mode == (ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE), (ssl_ca, ssl_verify_cert) - dummy_ssl_context.load_cert_chain.assert_called_with("cert", keyfile="key") + assert dummy_ssl_context.verify_mode == ( + ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE + ), (ssl_ca, ssl_verify_cert) + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", keyfile="key" + ) dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_ca="ca", ssl_cert="cert", @@ -608,9 +730,12 @@ def test_ssl_connect(self): dummy_ssl_context.set_ciphers.assert_not_called dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_disabled=True, ssl={ @@ -622,9 +747,12 @@ def test_ssl_connect(self): assert not create_default_context.called dummy_ssl_context = mock.Mock(options=0) - with mock.patch("pymysql.connections.Connection.connect") as connect, \ - mock.patch("pymysql.connections.ssl.create_default_context", - new=mock.Mock(return_value=dummy_ssl_context)) as create_default_context: + with mock.patch( + "pymysql.connections.Connection.connect" + ) as connect, mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: pymysql.connect( ssl_disabled=True, ssl_ca="ca", @@ -679,7 +807,7 @@ class Custom(str): pass mapping = {str: pymysql.escape_string} - self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'") + self.assertEqual(con.escape(Custom("foobar"), mapping), "'foobar'") def test_escape_no_default(self): con = self.connect() @@ -693,7 +821,7 @@ def test_escape_dict_value(self): mapping = con.encoders.copy() mapping[Foo] = escape_foo - self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"}) + self.assertEqual(con.escape({"foo": Foo()}, mapping), {"foo": "bar"}) def test_escape_list_item(self): con = self.connect() @@ -706,7 +834,8 @@ def test_escape_list_item(self): def test_previous_cursor_not_closed(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cur1 = con.cursor() cur1.execute("SELECT 1; SELECT 2") cur2 = con.cursor() diff --git a/pymysql/tests/test_converters.py b/pymysql/tests/test_converters.py index c2c9b6bf..dc194a9e 100644 --- a/pymysql/tests/test_converters.py +++ b/pymysql/tests/test_converters.py @@ -7,34 +7,30 @@ class TestConverter(TestCase): - def test_escape_string(self): - self.assertEqual( - converters.escape_string(u"foo\nbar"), - u"foo\\nbar" - ) + self.assertEqual(converters.escape_string(u"foo\nbar"), u"foo\\nbar") def test_convert_datetime(self): expected = datetime.datetime(2007, 2, 24, 23, 6, 20) - dt = converters.convert_datetime('2007-02-24 23:06:20') + dt = converters.convert_datetime("2007-02-24 23:06:20") self.assertEqual(dt, expected) def test_convert_datetime_with_fsp(self): expected = datetime.datetime(2007, 2, 24, 23, 6, 20, 511581) - dt = converters.convert_datetime('2007-02-24 23:06:20.511581') + dt = converters.convert_datetime("2007-02-24 23:06:20.511581") self.assertEqual(dt, expected) def _test_convert_timedelta(self, with_negate=False, with_fsp=False): - d = {'hours': 789, 'minutes': 12, 'seconds': 34} - s = '%(hours)s:%(minutes)s:%(seconds)s' % d + d = {"hours": 789, "minutes": 12, "seconds": 34} + s = "%(hours)s:%(minutes)s:%(seconds)s" % d if with_fsp: - d['microseconds'] = 511581 - s += '.%(microseconds)s' % d + d["microseconds"] = 511581 + s += ".%(microseconds)s" % d expected = datetime.timedelta(**d) if with_negate: expected = -expected - s = '-' + s + s = "-" + s tdelta = converters.convert_timedelta(s) self.assertEqual(tdelta, expected) @@ -49,10 +45,10 @@ def test_convert_timedelta_with_fsp(self): def test_convert_time(self): expected = datetime.time(23, 6, 20) - time_obj = converters.convert_time('23:06:20') + time_obj = converters.convert_time("23:06:20") self.assertEqual(time_obj, expected) def test_convert_time_with_fsp(self): expected = datetime.time(23, 6, 20, 511581) - time_obj = converters.convert_time('23:06:20.511581') + time_obj = converters.convert_time("23:06:20.511581") self.assertEqual(time_obj, expected) diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py index 4c9174f5..783caf88 100644 --- a/pymysql/tests/test_cursor.py +++ b/pymysql/tests/test_cursor.py @@ -3,6 +3,7 @@ from pymysql.tests import base import pymysql.cursors + class CursorTest(base.PyMySQLTestCase): def setUp(self): super(CursorTest, self).setUp() @@ -10,12 +11,14 @@ def setUp(self): conn = self.connect() self.safe_create_table( conn, - "test", "create table test (data varchar(10))", + "test", + "create table test (data varchar(10))", ) cursor = conn.cursor() cursor.execute( "insert into test (data) values " - "('row1'), ('row2'), ('row3'), ('row4'), ('row5')") + "('row1'), ('row2'), ('row3'), ('row4'), ('row5')" + ) cursor.close() self.test_connection = pymysql.connect(**self.databases[0]) self.addCleanup(self.test_connection.close) @@ -51,55 +54,78 @@ def test_cleanup_rows_buffered(self): c2 = conn.cursor() c2.execute("select 1") - self.assertEqual( - c2.fetchone(), (1,) - ) + self.assertEqual(c2.fetchone(), (1,)) self.assertIsNone(c2.fetchone()) def test_executemany(self): conn = self.test_connection cursor = conn.cursor(pymysql.cursors.Cursor) - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)") - self.assertIsNotNone(m, 'error parse %s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%s, %s)" + ) + self.assertIsNotNone(m, "error parse %s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)") - self.assertIsNotNone(m, 'error parse %(name)s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)" + ) + self.assertIsNotNone(m, "error parse %(name)s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)") - self.assertIsNotNone(m, 'error parse %(id_name)s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)" + ) + self.assertIsNotNone(m, "error parse %(id_name)s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update") - self.assertIsNotNone(m, 'error parse %(id_name)s') - self.assertEqual(m.group(3), ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update" + ) + self.assertIsNotNone(m, "error parse %(id_name)s") + self.assertEqual( + m.group(3), + " ON duplicate update", + "group 3 not ON duplicate update, bug in RE_INSERT_VALUES?", + ) # https://github.com/PyMySQL/PyMySQL/pull/597 - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO bloup(foo, bar)VALUES(%s, %s)") + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO bloup(foo, bar)VALUES(%s, %s)" + ) assert m is not None # cursor._executed must bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" # list args data = range(10) cursor.executemany("insert into test (data) values (%s)", data) - self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query') + self.assertTrue( + cursor._executed.endswith(b",(7),(8),(9)"), + "execute many with %s not in one query", + ) # dict args - data_dict = [{'data': i} for i in range(10)] + data_dict = [{"data": i} for i in range(10)] cursor.executemany("insert into test (data) values (%(data)s)", data_dict) - self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query') + self.assertTrue( + cursor._executed.endswith(b",(7),(8),(9)"), + "execute many with %(data)s not in one query", + ) # %% in column set - cursor.execute("""\ + cursor.execute( + """\ CREATE TABLE percent_test ( `A%` INTEGER, - `B%` INTEGER)""") + `B%` INTEGER)""" + ) try: q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)" self.assertIsNotNone(pymysql.cursors.RE_INSERT_VALUES.match(q)) cursor.executemany(q, [(3, 4), (5, 6)]) - self.assertTrue(cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query") + self.assertTrue( + cursor._executed.endswith(b"(3, 4),(5, 6)"), + "executemany with %% not in one query", + ) finally: cursor.execute("DROP TABLE IF EXISTS percent_test") diff --git a/pymysql/tests/test_err.py b/pymysql/tests/test_err.py index bb6a5c49..6b54c6d0 100644 --- a/pymysql/tests/test_err.py +++ b/pymysql/tests/test_err.py @@ -7,9 +7,8 @@ class TestRaiseException(unittest.TestCase): - def test_raise_mysql_exception(self): data = b"\xff\x15\x04#28000Access denied" with self.assertRaises(err.OperationalError) as cm: err.raise_mysql_exception(data) - self.assertEqual(cm.exception.args, (1045, 'Access denied')) + self.assertEqual(cm.exception.args, (1045, "Access denied")) diff --git a/pymysql/tests/test_issues.py b/pymysql/tests/test_issues.py index 2e11ddb5..95765e54 100644 --- a/pymysql/tests/test_issues.py +++ b/pymysql/tests/test_issues.py @@ -11,6 +11,7 @@ __all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"] + class TestOldIssues(base.PyMySQLTestCase): def test_issue_3(self): """ undefined methods datetime_or_None, date_or_None """ @@ -21,7 +22,10 @@ def test_issue_3(self): c.execute("drop table if exists issue3") c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)") try: - c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None)) + c.execute( + "insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", + (None, None, None, None), + ) c.execute("select d from issue3") self.assertEqual(None, c.fetchone()[0]) c.execute("select t from issue3") @@ -29,7 +33,11 @@ def test_issue_3(self): c.execute("select dt from issue3") self.assertEqual(None, c.fetchone()[0]) c.execute("select ts from issue3") - self.assertIn(type(c.fetchone()[0]), (type(None), datetime.datetime), 'expected Python type None or datetime from SQL timestamp') + self.assertIn( + type(c.fetchone()[0]), + (type(None), datetime.datetime), + "expected Python type None or datetime from SQL timestamp", + ) finally: c.execute("drop table issue3") @@ -58,7 +66,7 @@ def test_issue_6(self): """ exception: TypeError: ord() expected a character, but string of length 0 found """ # ToDo: this test requires access to db 'mysql'. kwargs = self.databases[0].copy() - kwargs['db'] = "mysql" + kwargs["db"] = "mysql" conn = pymysql.connect(**kwargs) c = conn.cursor() c.execute("select * from user") @@ -71,10 +79,12 @@ def test_issue_8(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists test") - c.execute("""CREATE TABLE `test` (`station` int NOT NULL DEFAULT '0', `dh` + c.execute( + """CREATE TABLE `test` (`station` int NOT NULL DEFAULT '0', `dh` datetime NOT NULL DEFAULT '2015-01-01 00:00:00', `echeance` int NOT NULL DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY -KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") +KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""" + ) try: self.assertEqual(0, c.execute("SELECT * FROM test")) c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)") @@ -92,7 +102,7 @@ def test_issue_13(self): try: cur.execute("create table issue13 (t text)") # ticket says 18k - size = 18*1024 + size = 18 * 1024 cur.execute("insert into issue13 (t) values (%s)", ("x" * size,)) cur.execute("select t from issue13") # use assertTrue so that obscenely huge error messages don't print @@ -110,9 +120,9 @@ def test_issue_15(self): c.execute("drop table if exists issue15") c.execute("create table issue15 (t varchar(32))") try: - c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',)) + c.execute("insert into issue15 (t) values (%s)", (u"\xe4\xf6\xfc",)) c.execute("select t from issue15") - self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0]) + self.assertEqual(u"\xe4\xf6\xfc", c.fetchone()[0]) finally: c.execute("drop table issue15") @@ -123,15 +133,21 @@ def test_issue_16(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue16") - c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))") + c.execute( + "create table issue16 (name varchar(32) primary key, email varchar(32))" + ) try: - c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')") + c.execute( + "insert into issue16 (name, email) values ('pete', 'floydophone')" + ) c.execute("select email from issue16 where name=%s", ("pete",)) self.assertEqual("floydophone", c.fetchone()[0]) finally: c.execute("drop table issue16") - @pytest.mark.skip("test_issue_17() requires a custom, legacy MySQL configuration and will not be run.") + @pytest.mark.skip( + "test_issue_17() requires a custom, legacy MySQL configuration and will not be run." + ) def test_issue_17(self): """could not connect mysql use passwod""" conn = self.connect() @@ -146,7 +162,10 @@ def test_issue_17(self): c.execute("drop table if exists issue17") c.execute("create table issue17 (x varchar(32) primary key)") c.execute("insert into issue17 (x) values ('hello, world!')") - c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db) + c.execute( + "grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" + % db + ) conn.commit() conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db) @@ -156,6 +175,7 @@ def test_issue_17(self): finally: c.execute("drop table issue17") + class TestNewIssues(base.PyMySQLTestCase): def test_issue_34(self): try: @@ -168,8 +188,9 @@ def test_issue_34(self): def test_issue_33(self): conn = pymysql.connect(charset="utf8", **self.databases[0]) - self.safe_create_table(conn, u'hei\xdfe', - u'create table hei\xdfe (name varchar(32))') + self.safe_create_table( + conn, u"hei\xdfe", u"create table hei\xdfe (name varchar(32))" + ) c = conn.cursor() c.execute(u"insert into hei\xdfe (name) values ('Pi\xdfata')") c.execute(u"select name from hei\xdfe") @@ -233,7 +254,7 @@ def test_issue_37(self): def test_issue_38(self): conn = self.connect() c = conn.cursor() - datum = "a" * 1024 * 1023 # reduced size for most default mysql installs + datum = "a" * 1024 * 1023 # reduced size for most default mysql installs try: with warnings.catch_warnings(): @@ -251,7 +272,7 @@ def disabled_test_issue_54(self): warnings.filterwarnings("ignore") c.execute("drop table if exists issue54") big_sql = "select * from issue54 where " - big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000)) + big_sql += " and ".join("%d=%d" % (i, i) for i in range(0, 100000)) try: c.execute("create table issue54 (id integer primary key)") @@ -261,6 +282,7 @@ def disabled_test_issue_54(self): finally: c.execute("drop table issue54") + class TestGitHubIssues(base.PyMySQLTestCase): def test_issue_66(self): """ 'Connection' object has no attribute 'insert_id' """ @@ -271,7 +293,9 @@ def test_issue_66(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue66") - c.execute("create table issue66 (id integer primary key auto_increment, x integer)") + c.execute( + "create table issue66 (id integer primary key auto_increment, x integer)" + ) c.execute("insert into issue66 (x) values (1)") c.execute("insert into issue66 (x) values (1)") self.assertEqual(2, conn.insert_id()) @@ -290,17 +314,17 @@ def test_issue_79(self): c.execute("""CREATE TABLE a (id int, value int)""") c.execute("""CREATE TABLE b (id int, value int)""") - a=(1,11) - b=(1,22) + a = (1, 11) + b = (1, 22) try: c.execute("insert into a values (%s, %s)", a) c.execute("insert into b values (%s, %s)", b) c.execute("SELECT * FROM a inner join b on a.id = b.id") r = c.fetchall()[0] - self.assertEqual(r['id'], 1) - self.assertEqual(r['value'], 11) - self.assertEqual(r['b.value'], 22) + self.assertEqual(r["id"], 1) + self.assertEqual(r["value"], 11) + self.assertEqual(r["b.value"], 22) finally: c.execute("drop table a") c.execute("drop table b") @@ -312,10 +336,12 @@ def test_issue_95(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore") cur.execute("DROP PROCEDURE IF EXISTS `foo`") - cur.execute("""CREATE PROCEDURE `foo` () + cur.execute( + """CREATE PROCEDURE `foo` () BEGIN SELECT 1; - END""") + END""" + ) try: cur.execute("""CALL foo()""") cur.execute("""SELECT 1""") @@ -355,40 +381,42 @@ def test_issue_175(self): conn = self.connect() cur = conn.cursor() for length in (200, 300): - columns = ', '.join('c{0} integer'.format(i) for i in range(length)) - sql = 'create table test_field_count ({0})'.format(columns) + columns = ", ".join("c{0} integer".format(i) for i in range(length)) + sql = "create table test_field_count ({0})".format(columns) try: cur.execute(sql) - cur.execute('select * from test_field_count') + cur.execute("select * from test_field_count") assert len(cur.description) == length finally: with warnings.catch_warnings(): warnings.filterwarnings("ignore") - cur.execute('drop table if exists test_field_count') + cur.execute("drop table if exists test_field_count") def test_issue_321(self): """ Test iterable as query argument. """ conn = pymysql.connect(charset="utf8", **self.databases[0]) self.safe_create_table( - conn, "issue321", - "create table issue321 (value_1 varchar(1), value_2 varchar(1))") + conn, + "issue321", + "create table issue321 (value_1 varchar(1), value_2 varchar(1))", + ) sql_insert = "insert into issue321 (value_1, value_2) values (%s, %s)" - sql_dict_insert = ("insert into issue321 (value_1, value_2) " - "values (%(value_1)s, %(value_2)s)") - sql_select = ("select * from issue321 where " - "value_1 in %s and value_2=%s") + sql_dict_insert = ( + "insert into issue321 (value_1, value_2) " + "values (%(value_1)s, %(value_2)s)" + ) + sql_select = "select * from issue321 where " "value_1 in %s and value_2=%s" data = [ - [(u"a", ), u"\u0430"], + [(u"a",), u"\u0430"], [[u"b"], u"\u0430"], - {"value_1": [[u"c"]], "value_2": u"\u0430"} + {"value_1": [[u"c"]], "value_2": u"\u0430"}, ] cur = conn.cursor() self.assertEqual(cur.execute(sql_insert, data[0]), 1) self.assertEqual(cur.execute(sql_insert, data[1]), 1) self.assertEqual(cur.execute(sql_dict_insert, data[2]), 1) - self.assertEqual( - cur.execute(sql_select, [(u"a", u"b", u"c"), u"\u0430"]), 3) + self.assertEqual(cur.execute(sql_select, [(u"a", u"b", u"c"), u"\u0430"]), 3) self.assertEqual(cur.fetchone(), (u"a", u"\u0430")) self.assertEqual(cur.fetchone(), (u"b", u"\u0430")) self.assertEqual(cur.fetchone(), (u"c", u"\u0430")) @@ -397,9 +425,11 @@ def test_issue_364(self): """ Test mixed unicode/binary arguments in executemany. """ conn = pymysql.connect(charset="utf8mb4", **self.databases[0]) self.safe_create_table( - conn, "issue364", + conn, + "issue364", "create table issue364 (value_1 binary(3), value_2 varchar(3)) " - "engine=InnoDB default charset=utf8mb4") + "engine=InnoDB default charset=utf8mb4", + ) sql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)" usql = u"insert into issue364 (value_1, value_2) values (_binary %s, %s)" @@ -427,11 +457,13 @@ def test_issue_363(self): """ Test binary / geometry types. """ conn = pymysql.connect(charset="utf8", **self.databases[0]) self.safe_create_table( - conn, "issue363", + conn, + "issue363", "CREATE TABLE issue363 ( " "id INTEGER PRIMARY KEY, geom LINESTRING NOT NULL /*!80003 SRID 0 */, " "SPATIAL KEY geom (geom)) " - "ENGINE=MyISAM") + "ENGINE=MyISAM", + ) cur = conn.cursor() # From MySQL 5.7, ST_GeomFromText is added and GeomFromText is deprecated. @@ -443,26 +475,32 @@ def test_issue_363(self): geom_from_text = "GeomFromText" geom_as_text = "AsText" geom_as_bin = "AsBinary" - query = ("INSERT INTO issue363 (id, geom) VALUES" - "(1998, %s('LINESTRING(1.1 1.1,2.2 2.2)'))" % geom_from_text) + query = ( + "INSERT INTO issue363 (id, geom) VALUES" + "(1998, %s('LINESTRING(1.1 1.1,2.2 2.2)'))" % geom_from_text + ) cur.execute(query) # select WKT query = "SELECT %s(geom) FROM issue363" % geom_as_text cur.execute(query) row = cur.fetchone() - self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)", )) + self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)",)) # select WKB query = "SELECT %s(geom) FROM issue363" % geom_as_bin cur.execute(query) row = cur.fetchone() - self.assertEqual(row, - (b"\x01\x02\x00\x00\x00\x02\x00\x00\x00" - b"\x9a\x99\x99\x99\x99\x99\xf1?" - b"\x9a\x99\x99\x99\x99\x99\xf1?" - b"\x9a\x99\x99\x99\x99\x99\x01@" - b"\x9a\x99\x99\x99\x99\x99\x01@", )) + self.assertEqual( + row, + ( + b"\x01\x02\x00\x00\x00\x02\x00\x00\x00" + b"\x9a\x99\x99\x99\x99\x99\xf1?" + b"\x9a\x99\x99\x99\x99\x99\xf1?" + b"\x9a\x99\x99\x99\x99\x99\x01@" + b"\x9a\x99\x99\x99\x99\x99\x01@", + ), + ) # select internal binary cur.execute("SELECT geom FROM issue363") diff --git a/pymysql/tests/test_load_local.py b/pymysql/tests/test_load_local.py index 30186e3a..bb856305 100644 --- a/pymysql/tests/test_load_local.py +++ b/pymysql/tests/test_load_local.py @@ -16,8 +16,10 @@ def test_no_file(self): self.assertRaises( OperationalError, c.execute, - ("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE " - "test_load_local fields terminated by ','") + ( + "LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE " + "test_load_local fields terminated by ','" + ), ) finally: c.execute("DROP TABLE test_load_local") @@ -28,13 +30,15 @@ def test_load_file(self): conn = self.connect() c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', - 'load_local_data.txt') + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt" + ) try: c.execute( - ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + - "test_load_local FIELDS TERMINATED BY ','").format(filename) + ( + "LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + + "test_load_local FIELDS TERMINATED BY ','" + ).format(filename) ) c.execute("SELECT COUNT(*) FROM test_load_local") self.assertEqual(22749, c.fetchone()[0]) @@ -46,13 +50,15 @@ def test_unbuffered_load_file(self): conn = self.connect() c = conn.cursor(cursors.SSCursor) c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', - 'load_local_data.txt') + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt" + ) try: c.execute( - ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + - "test_load_local FIELDS TERMINATED BY ','").format(filename) + ( + "LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + + "test_load_local FIELDS TERMINATED BY ','" + ).format(filename) ) c.execute("SELECT COUNT(*) FROM test_load_local") self.assertEqual(22749, c.fetchone()[0]) @@ -66,4 +72,5 @@ def test_unbuffered_load_file(self): if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_nextset.py b/pymysql/tests/test_nextset.py index d5467b11..2679edd5 100644 --- a/pymysql/tests/test_nextset.py +++ b/pymysql/tests/test_nextset.py @@ -7,11 +7,11 @@ class TestNextset(base.PyMySQLTestCase): - def test_nextset(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cur = con.cursor() cur.execute("SELECT 1; SELECT 2;") self.assertEqual([(1,)], list(cur)) @@ -71,14 +71,14 @@ def test_multi_cursor(self): def test_multi_statement_warnings(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cursor = con.cursor() try: - cursor.execute('DROP TABLE IF EXISTS a; ' - 'DROP TABLE IF EXISTS b;') + cursor.execute("DROP TABLE IF EXISTS a; " "DROP TABLE IF EXISTS b;") except TypeError: self.fail() - #TODO: How about SSCursor and nextset? + # TODO: How about SSCursor and nextset? # It's very hard to implement correctly... diff --git a/pymysql/tests/test_optionfile.py b/pymysql/tests/test_optionfile.py index 81bd1fe4..39bd47c4 100644 --- a/pymysql/tests/test_optionfile.py +++ b/pymysql/tests/test_optionfile.py @@ -3,20 +3,19 @@ from pymysql.optionfile import Parser -__all__ = ['TestParser'] +__all__ = ["TestParser"] -_cfg_file = (r""" +_cfg_file = r""" [default] string = foo quoted = "bar" single_quoted = 'foobar' skip-slave-start -""") +""" class TestParser(TestCase): - def test_string(self): parser = Parser() parser.read_file(StringIO(_cfg_file)) diff --git a/pymysql/tests/thirdparty/__init__.py b/pymysql/tests/thirdparty/__init__.py index 7a613478..d5f05371 100644 --- a/pymysql/tests/thirdparty/__init__.py +++ b/pymysql/tests/thirdparty/__init__.py @@ -2,4 +2,5 @@ if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py index e4237c69..57c42ce7 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py @@ -4,4 +4,5 @@ if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py index e261a78e..ffead0ca 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py @@ -22,7 +22,7 @@ def setUp(self): db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) self.connection = db self.cursor = db.cursor() - self.BLOBText = ''.join([chr(i) for i in range(256)] * 100); + self.BLOBText = "".join([chr(i) for i in range(256)] * 100) self.BLOBUText = "".join(chr(i) for i in range(16834)) data = bytearray(range(256)) * 16 self.BLOBBinary = self.db_module.Binary(data) @@ -32,17 +32,22 @@ def setUp(self): def tearDown(self): if self.leak_test: import gc + del self.cursor orphans = gc.collect() - self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans) + self.assertFalse( + orphans, "%d orphaned objects found after deleting cursor" % orphans + ) del self.connection orphans = gc.collect() - self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans) + self.assertFalse( + orphans, "%d orphaned objects found after deleting connection" % orphans + ) def table_exists(self, name): try: - self.cursor.execute('select * from %s where 1=0' % name) + self.cursor.execute("select * from %s where 1=0" % name) except Exception: return False else: @@ -54,7 +59,7 @@ def quote_identifier(self, ident): def new_table_name(self): i = id(self.cursor) while True: - name = self.quote_identifier('tb%08x' % i) + name = self.quote_identifier("tb%08x" % i) if not self.table_exists(name): return name i = i + 1 @@ -68,25 +73,27 @@ def create_table(self, columndefs): into the table. """ self.table = self.new_table_name() - self.cursor.execute('CREATE TABLE %s (%s) %s' % - (self.table, - ',\n'.join(columndefs), - self.create_table_extra)) + self.cursor.execute( + "CREATE TABLE %s (%s) %s" + % (self.table, ",\n".join(columndefs), self.create_table_extra) + ) def check_data_integrity(self, columndefs, generator): # insert self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] if self.debug: print(data) self.cursor.executemany(insert_statement, data) self.connection.commit() # verify - self.cursor.execute('select * from %s' % self.table) + self.cursor.execute("select * from %s" % self.table) l = self.cursor.fetchall() if self.debug: print(l) @@ -94,62 +101,74 @@ def check_data_integrity(self, columndefs, generator): try: for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) + self.assertEqual(l[i][j], generator(i, j)) finally: if not self.debug: - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_transactions(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*255 + if col == 0: + return row + else: + return ("%i" % (row % 10)) * 255 + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) # verify self.connection.commit() - self.cursor.execute('select * from %s' % self.table) + self.cursor.execute("select * from %s" % self.table) l = self.cursor.fetchall() self.assertEqual(len(l), self.rows) for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) - delete_statement = 'delete from %s where col1=%%s' % self.table + self.assertEqual(l[i][j], generator(i, j)) + delete_statement = "delete from %s where col1=%%s" % self.table self.cursor.execute(delete_statement, (0,)) - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) l = self.cursor.fetchall() self.assertFalse(l, "DELETE didn't work") self.connection.rollback() - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) l = self.cursor.fetchall() self.assertTrue(len(l) == 1, "ROLLBACK didn't work") - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_truncation(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*((255-self.rows//2)+row) + if col == 0: + return row + else: + return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row) + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) try: - self.cursor.execute(insert_statement, (0, '0'*256)) + self.cursor.execute(insert_statement, (0, "0" * 256)) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long column did not generate warnings/exception with single insert") + self.fail( + "Over-long column did not generate warnings/exception with single insert" + ) self.connection.rollback() @@ -157,132 +176,136 @@ def generator(row, col): for i in range(self.rows): data = [] for j in range(len(columndefs)): - data.append(generator(i,j)) - self.cursor.execute(insert_statement,tuple(data)) + data.append(generator(i, j)) + self.cursor.execute(insert_statement, tuple(data)) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with execute()") + self.fail( + "Over-long columns did not generate warnings/exception with execute()" + ) self.connection.rollback() try: - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + data = [ + [generator(i, j) for j in range(len(columndefs))] + for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with executemany()") + self.fail( + "Over-long columns did not generate warnings/exception with executemany()" + ) self.connection.rollback() - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_CHAR(self): # Character data - def generator(row,col): - return ('%i' % ((row+col) % 10)) * 255 - self.check_data_integrity( - ('col1 char(255)','col2 char(255)'), - generator) + def generator(row, col): + return ("%i" % ((row + col) % 10)) * 255 + + self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator) def test_INT(self): # Number data - def generator(row,col): - return row*row - self.check_data_integrity( - ('col1 INT',), - generator) + def generator(row, col): + return row * row + + self.check_data_integrity(("col1 INT",), generator) def test_DECIMAL(self): # DECIMAL - def generator(row,col): + def generator(row, col): from decimal import Decimal + return Decimal("%d.%02d" % (row, col)) - self.check_data_integrity( - ('col1 DECIMAL(5,2)',), - generator) + + self.check_data_integrity(("col1 DECIMAL(5,2)",), generator) def test_DATE(self): ticks = time() - def generator(row,col): - return self.db_module.DateFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATE',), - generator) + + def generator(row, col): + return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATE",), generator) def test_TIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIME',), - generator) + + def generator(row, col): + return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIME",), generator) def test_DATETIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATETIME',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATETIME",), generator) def test_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_fractional_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks( + ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0 + ) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_LONG(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBUText # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT', 'col2 LONG'), - generator) + return self.BLOBUText # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG"), generator) def test_TEXT(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBUText[:5192] # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT', 'col2 TEXT'), - generator) + return self.BLOBUText[:5192] # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 TEXT"), generator) def test_LONG_BYTE(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 LONG BYTE'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator) def test_BLOB(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 BLOB'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 BLOB"), generator) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py index 1cc202e2..6766aff3 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py @@ -1,4 +1,4 @@ -''' Python DB API 2.0 driver compliance unit test suite. +""" Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -8,11 +8,11 @@ this is turning out to be a thoroughly unwholesome unit test." -- Ian Bicking -''' +""" -__rcs_id__ = '$Id$' -__version__ = '$Revision$'[11:-2] -__author__ = 'Stuart Bishop ' +__rcs_id__ = "$Id$" +__version__ = "$Revision$"[11:-2] +__author__ = "Stuart Bishop " import time import unittest @@ -63,65 +63,66 @@ # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # + class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. + """Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. - The 'Optional Extensions' are not yet being tested. + The 'Optional Extensions' are not yet being tested. - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = "dbapi20test_" # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + ddl1 = "create table %sbooze (name varchar(20))" % table_prefix + ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix + xddl1 = "drop table %sbooze" % table_prefix + xddl2 = "drop table %sbarflys" % table_prefix - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = "lower" # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def executeDDL1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - ''' self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. - ''' + """self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + """ pass def tearDown(self): - ''' self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created. - ''' + """self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + """ con = self._connect() try: cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): + for ddl in (self.xddl1, self.xddl2): try: cur.execute(ddl) con.commit() @@ -134,9 +135,7 @@ def tearDown(self): def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + return self.driver.connect(*self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") @@ -149,7 +148,7 @@ def test_apilevel(self): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, "2.0") except AttributeError: self.fail("Driver doesn't define apilevel") @@ -158,7 +157,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -167,38 +166,24 @@ def test_paramstyle(self): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + self.assertTrue( + paramstyle in ("qmark", "numeric", "named", "format", "pyformat") + ) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined heirarchy. - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION @@ -219,7 +204,6 @@ def test_ExceptionsAsConnectionAttributes(self): self.assertTrue(con.ProgrammingError is drv.ProgrammingError) self.assertTrue(con.NotSupportedError is drv.NotSupportedError) - def test_commit(self): con = self._connect() try: @@ -232,7 +216,7 @@ def test_rollback(self): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): + if hasattr(con, "rollback"): try: con.rollback() except self.driver.NotSupportedError: @@ -253,14 +237,14 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) + cur1.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], "Victoria Bitter") finally: con.close() @@ -269,31 +253,41 @@ def test_description(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + len(cur.description), 1, "cursor.description describes too many columns" + ) + self.assertEqual( + len(cur.description[0]), + 7, + "cursor.description[x] tuples must have 7 elements", + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" + % cur.description[0][1], + ) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. DDL)", + ) finally: con.close() @@ -302,47 +296,49 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount should be -1 after executing no-result " "statements", + ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount not being reset to -1 after executing " + "no-result statements", + ) finally: con.close() - lower_func = 'lower' + lower_func = "lower" + def test_callproc(self): con = self._connect() try: cur = con.cursor() - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) - self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, "callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") finally: con.close() @@ -355,14 +351,14 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' - self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) + self.assertRaises(self.driver.Error, con.close) def test_execute(self): con = self._connect() @@ -372,105 +368,99 @@ def test_execute(self): finally: con.close() - def _paraminsert(self,cur): + def _paraminsert(self, cur): self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue(cur.rowcount in (-1, 1)) - if self.driver.paramstyle == 'qmark': + if self.driver.paramstyle == "qmark": cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "numeric": cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "named": cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, + {"beer": "Cooper's"}, + ) + elif self.driver.paramstyle == "format": cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "pyformat": cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "insert into %sbooze values (%%(beer)s)" % self.table_prefix, + {"beer": "Cooper's"}, + ) else: - self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.fail("Invalid paramstyle") + self.assertTrue(cur.rowcount in (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], + "Cooper's", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) + self.assertEqual( + beers[1], + "Victoria Bitter", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] - if self.driver.paramstyle == 'qmark': + largs = [("Cooper's",), ("Boag's",)] + margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] + if self.driver.paramstyle == "qmark": cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "numeric": cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "named": cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, margs + ) + elif self.driver.paramstyle == "format": cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "pyformat": cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) - else: - self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount + "insert into %sbooze values (%%(beer)s)" % (self.table_prefix), + margs, ) - cur.execute('select name from %sbooze' % self.table_prefix) + else: + self.fail("Unknown paramstyle") + self.assertTrue( + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual( + len(res), 2, "cursor.fetchall retrieved incorrect number of rows" + ) + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") + self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") finally: con.close() @@ -481,59 +471,62 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if no more rows available", + ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() samples = [ - 'Carlton Cold', - 'Carlton Draft', - 'Mountain Goat', - 'Redback', - 'Victoria Bitter', - 'XXXX' - ] + "Carlton Cold", + "Carlton Draft", + "Mountain Goat", + "Redback", + "Victoria Bitter", + "XXXX", + ] def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch - tests. - ''' + """Return a list of sql commands to setup the DB for the fetch + tests. + """ populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + "insert into %sbooze values ('%s')" % (self.table_prefix, s) + for s in self.samples + ] return populate def test_fetchmany(self): @@ -542,78 +535,88 @@ def test_fetchmany(self): cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", + ) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual( + len(r), 3, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " + "results are exhausted", ) - self.assertTrue(cur.rowcount in (-1,6)) + self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + cur.arraysize = 4 + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 4, "cursor.arraysize not being honoured by fetchmany" + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + self.assertTrue(cur.rowcount in (-1, 6)) + + cur.arraysize = 6 + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' + for i in range(0, 6): + self.assertEqual( + rows[i], + self.samples[i], + "incorrect data retrieved by cursor.fetchmany", ) - self.assertTrue(cur.rowcount in (-1,6)) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbarflys" % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence if " + "query retrieved no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() @@ -633,36 +636,41 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -675,74 +683,74 @@ def test_mixedfetch(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + cur.execute("select name from %sbooze" % self.table_prefix) + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(rows23), 2, "fetchmany returned incorrect number of rows" + ) + self.assertEqual( + len(rows56), 2, "fetchall returned incorrect number of rows" + ) rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "incorrect data retrieved or inserted" + ) finally: con.close() - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - raise NotImplementedError('Helper not implemented') - #sql=""" + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + raise NotImplementedError("Helper not implemented") + # sql=""" # create procedure deleteme as # begin # select count(*) from booze # select name from booze # end - #""" - #cur.execute(sql) + # """ + # cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' - raise NotImplementedError('Helper not implemented') - #cur.execute("drop procedure deleteme") + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" + raise NotImplementedError("Helper not implemented") + # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' + s = cur.nextset() + assert s == None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) @@ -750,16 +758,16 @@ def test_nextset(self): con.close() def test_nextset(self): - raise NotImplementedError('Drivers need to override this test') + raise NotImplementedError("Drivers need to override this test") def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue( + hasattr(cur, "arraysize"), "cursor.arraysize must be defined" + ) finally: con.close() @@ -767,8 +775,8 @@ def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() @@ -778,74 +786,70 @@ def test_setoutputsize_basic(self): try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): # Real test for setoutputsize is driver dependant - raise NotImplementedError('Driver need to override this test') + raise NotImplementedError("Driver need to override this test") def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("insert into %sbooze values (NULL)" % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, "NULL value not returned as None") finally: con.close() def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + d1 = self.driver.Date(2002, 12, 25) + d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + t1 = self.driver.Time(13, 45, 30) + t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') + b = self.driver.Binary(b"Something") + b = self.driver.Binary(b"") def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "BINARY"), "module.BINARY must be defined." + ) def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." + ) def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." + ) def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) + self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py index 8c1dd535..139089ab 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py @@ -4,16 +4,23 @@ from pymysql.tests import base import warnings -warnings.filterwarnings('error') +warnings.filterwarnings("error") + class test_MySQLdb(capabilities.DatabaseTest): db_module = pymysql connect_args = () connect_kwargs = base.PyMySQLTestCase.databases[0].copy() - connect_kwargs.update(dict(read_default_file='~/.my.cnf', - use_unicode=True, binary_prefix=True, - charset='utf8mb4', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + connect_kwargs.update( + dict( + read_default_file="~/.my.cnf", + use_unicode=True, + binary_prefix=True, + charset="utf8mb4", + sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL", + ) + ) leak_test = False @@ -22,64 +29,70 @@ def quote_identifier(self, ident): def test_TIME(self): from datetime import timedelta - def generator(row,col): - return timedelta(0, row*8000) - self.check_data_integrity( - ('col1 TIME',), - generator) + + def generator(row, col): + return timedelta(0, row * 8000) + + self.check_data_integrity(("col1 TIME",), generator) def test_TINYINT(self): # Number data - def generator(row,col): - v = (row*row) % 256 + def generator(row, col): + v = (row * row) % 256 if v > 127: - v = v-256 + v = v - 256 return v - self.check_data_integrity( - ('col1 TINYINT',), - generator) + + self.check_data_integrity(("col1 TINYINT",), generator) def test_stored_procedures(self): db = self.connection c = self.cursor try: - self.create_table(('pos INT', 'tree CHAR(20)')) - c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, - list(enumerate('ash birch cedar larch pine'.split()))) + self.create_table(("pos INT", "tree CHAR(20)")) + c.executemany( + "INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, + list(enumerate("ash birch cedar larch pine".split())), + ) db.commit() - c.execute(""" + c.execute( + """ CREATE PROCEDURE test_sp(IN t VARCHAR(255)) BEGIN SELECT pos FROM %s WHERE tree = t; END - """ % self.table) + """ + % self.table + ) db.commit() - c.callproc('test_sp', ('larch',)) + c.callproc("test_sp", ("larch",)) rows = c.fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0][0], 3) c.nextset() finally: c.execute("DROP PROCEDURE IF EXISTS test_sp") - c.execute('drop table %s' % (self.table)) + c.execute("drop table %s" % (self.table)) def test_small_CHAR(self): # Character data - def generator(row,col): - i = ((row+1)*(col+1)+62)%256 - if i == 62: return '' - if i == 63: return None + def generator(row, col): + i = ((row + 1) * (col + 1) + 62) % 256 + if i == 62: + return "" + if i == 63: + return None return chr(i) - self.check_data_integrity( - ('col1 char(1)','col2 char(1)'), - generator) + + self.check_data_integrity(("col1 char(1)", "col2 char(1)"), generator) def test_bug_2671682(self): from pymysql.constants import ER + try: - self.cursor.execute("describe some_non_existent_table"); + self.cursor.execute("describe some_non_existent_table") except self.connection.ProgrammingError as msg: self.assertEqual(msg.args[0], ER.NO_SUCH_TABLE) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py index 2c9a0600..e882c5eb 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py @@ -9,13 +9,22 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): driver = pymysql connect_args = () connect_kw_args = base.PyMySQLTestCase.databases[0].copy() - connect_kw_args.update(dict(read_default_file='~/.my.cnf', - charset='utf8', - sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + connect_kw_args.update( + dict( + read_default_file="~/.my.cnf", + charset="utf8", + sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL", + ) + ) - def test_setoutputsize(self): pass - def test_setoutputsize_basic(self): pass - def test_nextset(self): pass + def test_setoutputsize(self): + pass + + def test_setoutputsize_basic(self): + pass + + def test_nextset(self): + pass """The tests on fetchone and fetchall and rowcount bogusly test for an exception if the statement cannot return a @@ -37,36 +46,41 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows -## self.assertRaises(self.driver.Error,cur.fetchall) + ## self.assertRaises(self.driver.Error,cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -78,39 +92,40 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) -## self.assertRaises(self.driver.Error,cur.fetchone) + ## self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + ## self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) -## self.assertEqual(cur.fetchone(),None, -## 'cursor.fetchone should return None if no more rows available' -## ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + ## self.assertEqual(cur.fetchone(),None, + ## 'cursor.fetchone should return None if no more rows available' + ## ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() @@ -120,81 +135,86 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount should be -1 after executing no-result ' -## 'statements' -## ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertTrue(cur.rowcount in (-1,1), -## 'cursor.rowcount should == number or rows inserted, or ' -## 'set to -1 after executing an insert statement' -## ) + ## self.assertEqual(cur.rowcount,-1, + ## 'cursor.rowcount should be -1 after executing no-result ' + ## 'statements' + ## ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + ## self.assertTrue(cur.rowcount in (-1,1), + ## 'cursor.rowcount should == number or rows inserted, or ' + ## 'set to -1 after executing an insert statement' + ## ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount not being reset to -1 after executing ' -## 'no-result statements' -## ) + ## self.assertEqual(cur.rowcount,-1, + ## 'cursor.rowcount not being reset to -1 after executing ' + ## 'no-result statements' + ## ) finally: con.close() def test_callproc(self): - pass # performed in test_MySQL_capabilities - - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - sql=""" + pass # performed in test_MySQL_capabilities + + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + sql = """ create procedure deleteme() begin select count(*) from %(tp)sbooze; select name from %(tp)sbooze; end - """ % dict(tp=self.table_prefix) + """ % dict( + tp=self.table_prefix + ) cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" cur.execute("drop procedure deleteme") def test_nextset(self): from warnings import warn + con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() + s = cur.nextset() if s: empty = cur.fetchall() - self.assertEqual(len(empty), 0, - "non-empty result set after other result sets") - #warn("Incompatibility: MySQL returns an empty result set for the CALL itself", + self.assertEqual( + len(empty), 0, "non-empty result set after other result sets" + ) + # warn("Incompatibility: MySQL returns an empty result set for the CALL itself", # Warning) - #assert s == None,'No more return sets, should return None' + # assert s == None,'No more return sets, should return None' finally: self.help_nextset_tearDown(cur) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py index 747ea4b0..b8d4bb1e 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py @@ -2,6 +2,7 @@ import unittest import pymysql + _mysql = pymysql from pymysql.constants import FIELD_TYPE from pymysql.tests import base @@ -26,7 +27,7 @@ class CoreModule(unittest.TestCase): def test_NULL(self): """Should have a NULL constant.""" - self.assertEqual(_mysql.NULL, 'NULL') + self.assertEqual(_mysql.NULL, "NULL") def test_version(self): """Version information sanity.""" @@ -55,36 +56,45 @@ def tearDown(self): def test_thread_id(self): tid = self.conn.thread_id() - self.assertTrue(isinstance(tid, int), - "thread_id didn't return an integral value.") + self.assertTrue( + isinstance(tid, int), "thread_id didn't return an integral value." + ) - self.assertRaises(TypeError, self.conn.thread_id, ('evil',), - "thread_id shouldn't accept arguments.") + self.assertRaises( + TypeError, + self.conn.thread_id, + ("evil",), + "thread_id shouldn't accept arguments.", + ) def test_affected_rows(self): - self.assertEqual(self.conn.affected_rows(), 0, - "Should return 0 before we do anything.") - + self.assertEqual( + self.conn.affected_rows(), 0, "Should return 0 before we do anything." + ) - #def test_debug(self): - ## FIXME Only actually tests if you lack SUPER - #self.assertRaises(pymysql.OperationalError, - #self.conn.dump_debug_info) + # def test_debug(self): + ## FIXME Only actually tests if you lack SUPER + # self.assertRaises(pymysql.OperationalError, + # self.conn.dump_debug_info) def test_charset_name(self): - self.assertTrue(isinstance(self.conn.character_set_name(), str), - "Should return a string.") + self.assertTrue( + isinstance(self.conn.character_set_name(), str), "Should return a string." + ) def test_host_info(self): assert isinstance(self.conn.get_host_info(), str), "should return a string" def test_proto_info(self): - self.assertTrue(isinstance(self.conn.get_proto_info(), int), - "Should return an int.") + self.assertTrue( + isinstance(self.conn.get_proto_info(), int), "Should return an int." + ) def test_server_info(self): - self.assertTrue(isinstance(self.conn.get_server_info(), str), - "Should return an str.") + self.assertTrue( + isinstance(self.conn.get_server_info(), str), "Should return an str." + ) + if __name__ == "__main__": unittest.main() diff --git a/pymysql/util.py b/pymysql/util.py index 04683f83..1349ec7b 100644 --- a/pymysql/util.py +++ b/pymysql/util.py @@ -10,4 +10,3 @@ def byte2int(b): def int2byte(i): return struct.pack("!B", i) - diff --git a/tests/test_auth.py b/tests/test_auth.py index 61957655..e5e2a64e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -10,7 +10,7 @@ port = 3306 ca = os.path.expanduser("~/ca.pem") -ssl = {'ca': ca, 'check_hostname': False} +ssl = {"ca": ca, "check_hostname": False} pass_sha256 = "pass_sha256_01234567890123456789" pass_caching_sha2 = "pass_caching_sha2_01234567890123456789" @@ -27,12 +27,16 @@ def test_sha256_no_passowrd_ssl(): def test_sha256_password(): - con = pymysql.connect(user="user_sha256", password=pass_sha256, host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_sha256", password=pass_sha256, host=host, port=port, ssl=None + ) con.close() def test_sha256_password_ssl(): - con = pymysql.connect(user="user_sha256", password=pass_sha256, host=host, port=port, ssl=ssl) + con = pymysql.connect( + user="user_sha256", password=pass_sha256, host=host, port=port, ssl=ssl + ) con.close() @@ -47,20 +51,44 @@ def test_caching_sha2_no_password_ssl(): def test_caching_sha2_password(): - con = pymysql.connect(user="user_caching_sha2", password=pass_caching_sha2, host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.close() # Fast path of caching sha2 - con = pymysql.connect(user="user_caching_sha2", password=pass_caching_sha2, host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.query("FLUSH PRIVILEGES") con.close() def test_caching_sha2_password_ssl(): - con = pymysql.connect(user="user_caching_sha2", password=pass_caching_sha2, host=host, port=port, ssl=ssl) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=ssl, + ) con.close() # Fast path of caching sha2 - con = pymysql.connect(user="user_caching_sha2", password=pass_caching_sha2, host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.query("FLUSH PRIVILEGES") con.close() diff --git a/tests/test_mariadb_auth.py b/tests/test_mariadb_auth.py index 2f336fec..b3a2719c 100644 --- a/tests/test_mariadb_auth.py +++ b/tests/test_mariadb_auth.py @@ -15,8 +15,9 @@ def test_ed25519_no_password(): def test_ed25519_password(): # nosec - con = pymysql.connect(user="user_ed25519", password="pass_ed25519", - host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_ed25519", password="pass_ed25519", host=host, port=port, ssl=None + ) con.close()