Skip to content

Reformat with black #920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 96 additions & 29 deletions pymysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,26 @@
from .constants import FIELD_TYPE
from .converters import escape_dict, escape_sequence, escape_string
from .err import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError, MySQLError)
Warning,
Error,
InterfaceError,
DataError,
DatabaseError,
OperationalError,
IntegrityError,
InternalError,
NotSupportedError,
ProgrammingError,
MySQLError,
)
from .times import (
Date, Time, Timestamp,
DateFromTicks, TimeFromTicks, TimestampFromTicks)
Date,
Time,
Timestamp,
DateFromTicks,
TimeFromTicks,
TimestampFromTicks,
)


VERSION = (0, 10, 1, None)
Expand All @@ -45,7 +59,6 @@


class DBAPISet(frozenset):

def __ne__(self, other):
if isinstance(other, set):
return frozenset.__ne__(self, other)
Expand All @@ -62,18 +75,32 @@ def __hash__(self):
return frozenset.__hash__(self)


STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
FIELD_TYPE.VAR_STRING])
BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB,
FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB])
NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
FIELD_TYPE.TINY, FIELD_TYPE.YEAR])
DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
TIME = DBAPISet([FIELD_TYPE.TIME])
STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING])
BINARY = DBAPISet(
[
FIELD_TYPE.BLOB,
FIELD_TYPE.LONG_BLOB,
FIELD_TYPE.MEDIUM_BLOB,
FIELD_TYPE.TINY_BLOB,
]
)
NUMBER = DBAPISet(
[
FIELD_TYPE.DECIMAL,
FIELD_TYPE.DOUBLE,
FIELD_TYPE.FLOAT,
FIELD_TYPE.INT24,
FIELD_TYPE.LONG,
FIELD_TYPE.LONGLONG,
FIELD_TYPE.TINY,
FIELD_TYPE.YEAR,
]
)
DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
TIME = DBAPISet([FIELD_TYPE.TIME])
TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
DATETIME = TIMESTAMP
ROWID = DBAPISet()
DATETIME = TIMESTAMP
ROWID = DBAPISet()


def Binary(x):
Expand All @@ -87,9 +114,12 @@ def Connect(*args, **kwargs):
more information.
"""
from .connections import Connection

return Connection(*args, **kwargs)


from . import connections as _orig_conn

if _orig_conn.Connection.__init__.__doc__ is not None:
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__
del _orig_conn
Expand All @@ -99,7 +129,8 @@ def get_client_info(): # for MySQLdb compatibility
version = VERSION
if VERSION[3] is None:
version = VERSION[:3]
return '.'.join(map(str, version))
return ".".join(map(str, version))


connect = Connection = Connect

Expand All @@ -110,9 +141,11 @@ def get_client_info(): # for MySQLdb compatibility

__version__ = get_client_info()


def thread_safe():
return True # match MySQLdb.thread_safe()


def install_as_MySQLdb():
"""
After this function is called, any application that imports MySQLdb or
Expand All @@ -122,16 +155,50 @@ def install_as_MySQLdb():


__all__ = [
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError',
'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
'connections', 'constants', 'converters', 'cursors',
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
'paramstyle', 'threadsafety', 'version_info',

"BINARY",
"Binary",
"Connect",
"Connection",
"DATE",
"Date",
"Time",
"Timestamp",
"DateFromTicks",
"TimeFromTicks",
"TimestampFromTicks",
"DataError",
"DatabaseError",
"Error",
"FIELD_TYPE",
"IntegrityError",
"InterfaceError",
"InternalError",
"MySQLError",
"NULL",
"NUMBER",
"NotSupportedError",
"DBAPISet",
"OperationalError",
"ProgrammingError",
"ROWID",
"STRING",
"TIME",
"TIMESTAMP",
"Warning",
"apilevel",
"connect",
"connections",
"constants",
"converters",
"cursors",
"escape_dict",
"escape_sequence",
"escape_string",
"get_client_info",
"paramstyle",
"threadsafety",
"version_info",
"install_as_MySQLdb",
"NULL", "__version__",
"NULL",
"__version__",
]
45 changes: 26 additions & 19 deletions pymysql/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding

_have_cryptography = True
except ImportError:
_have_cryptography = False
Expand All @@ -22,7 +23,7 @@

DEBUG = False
SCRAMBLE_LENGTH = 20
sha1_new = partial(hashlib.new, 'sha1')
sha1_new = partial(hashlib.new, "sha1")


# mysql_native_password
Expand All @@ -32,7 +33,7 @@
def scramble_native_password(password, message):
"""Scramble used for mysql_native_password"""
if not password:
return b''
return b""

stage1 = sha1_new(password).digest()
stage2 = sha1_new(stage1).digest()
Expand All @@ -59,7 +60,6 @@ def _my_crypt(message1, message2):


class RandStruct_323:

def __init__(self, seed1, seed2):
self.max_value = 0x3FFFFFFF
self.seed1 = seed1 % self.max_value
Expand All @@ -73,8 +73,10 @@ def my_rnd(self):

def scramble_old_password(password, message):
"""Scramble for old_password"""
warnings.warn("old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n"
"old password support will be removed in future PyMySQL version")
warnings.warn(
"old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n"
"old password support will be removed in future PyMySQL version"
)
hash_pass = _hash_password_323(password)
hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
hash_pass_n = struct.unpack(">LL", hash_pass)
Expand All @@ -100,7 +102,7 @@ def _hash_password_323(password):
nr2 = 0x12345671

# x in py3 is numbers, p27 is chars
for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
for c in [byte2int(x) for x in password if x not in (" ", "\t", 32, 9)]:
nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
add = (add + c) & 0xFFFFFFFF
Expand All @@ -120,9 +122,12 @@ def _init_nacl():
global _nacl_bindings
try:
from nacl import bindings

_nacl_bindings = bindings
except ImportError:
raise RuntimeError("'pynacl' package is required for ed25519_password auth method")
raise RuntimeError(
"'pynacl' package is required for ed25519_password auth method"
)


def _scalar_clamp(s32):
Expand Down Expand Up @@ -185,7 +190,7 @@ def _xor_password(password, salt):
# See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/auth/sha2_password.cc#L939-L945
salt = salt[:SCRAMBLE_LENGTH]
password_bytes = bytearray(password)
#salt = bytearray(salt) # for PY2 compat.
# salt = bytearray(salt) # for PY2 compat.
salt_len = len(salt)
for i in range(len(password_bytes)):
password_bytes[i] ^= salt[i % salt_len]
Expand All @@ -198,8 +203,10 @@ def sha2_rsa_encrypt(password, salt, public_key):
Used for sha256_password and caching_sha2_password.
"""
if not _have_cryptography:
raise RuntimeError("'cryptography' package is required for sha256_password or caching_sha2_password auth methods")
message = _xor_password(password + b'\0', salt)
raise RuntimeError(
"'cryptography' package is required for sha256_password or caching_sha2_password auth methods"
)
message = _xor_password(password + b"\0", salt)
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
return rsa_key.encrypt(
message,
Expand All @@ -215,7 +222,7 @@ def sha256_password_auth(conn, pkt):
if conn._secure:
if DEBUG:
print("sha256: Sending plain password")
data = conn.password + b'\0'
data = conn.password + b"\0"
return _roundtrip(conn, data)

if pkt.is_auth_switch_request():
Expand All @@ -224,20 +231,20 @@ def sha256_password_auth(conn, pkt):
# Request server public key
if DEBUG:
print("sha256: Requesting server public key")
pkt = _roundtrip(conn, b'\1')
pkt = _roundtrip(conn, b"\1")

if pkt.is_extra_auth_data():
conn.server_public_key = pkt._data[1:]
if DEBUG:
print("Received public key:\n", conn.server_public_key.decode('ascii'))
print("Received public key:\n", conn.server_public_key.decode("ascii"))

if conn.password:
if not conn.server_public_key:
raise OperationalError("Couldn't receive server's public key")

data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
else:
data = b''
data = b""

return _roundtrip(conn, data)

Expand All @@ -249,7 +256,7 @@ def scramble_caching_sha2(password, nonce):
XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))
"""
if not password:
return b''
return b""

p1 = hashlib.sha256(password).digest()
p2 = hashlib.sha256(p1).digest()
Expand All @@ -265,7 +272,7 @@ def scramble_caching_sha2(password, nonce):
def caching_sha2_password_auth(conn, pkt):
# No password fast path
if not conn.password:
return _roundtrip(conn, b'')
return _roundtrip(conn, b"")

if pkt.is_auth_switch_request():
# Try from fast auth
Expand Down Expand Up @@ -305,18 +312,18 @@ def caching_sha2_password_auth(conn, pkt):
if conn._secure:
if DEBUG:
print("caching sha2: Sending plain password via secure connection")
return _roundtrip(conn, conn.password + b'\0')
return _roundtrip(conn, conn.password + b"\0")

if not conn.server_public_key:
pkt = _roundtrip(conn, b'\x02') # Request public key
pkt = _roundtrip(conn, b"\x02") # Request public key
if not pkt.is_extra_auth_data():
raise OperationalError(
"caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
)

conn.server_public_key = pkt._data[1:]
if DEBUG:
print(conn.server_public_key.decode('ascii'))
print(conn.server_public_key.decode("ascii"))

data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
pkt = _roundtrip(conn, data)
16 changes: 6 additions & 10 deletions pymysql/_socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import io
import errno

__all__ = ['SocketIO']
__all__ = ["SocketIO"]

EINTR = errno.EINTR
_blocking_errnos = (errno.EAGAIN, errno.EWOULDBLOCK)


class SocketIO(io.RawIOBase):

"""Raw I/O implementation for stream sockets.
Expand Down Expand Up @@ -85,29 +86,25 @@ def write(self, b):
raise

def readable(self):
"""True if the SocketIO is open for reading.
"""
"""True if the SocketIO is open for reading."""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return self._reading

def writable(self):
"""True if the SocketIO is open for writing.
"""
"""True if the SocketIO is open for writing."""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return self._writing

def seekable(self):
"""True if the SocketIO is open for seeking.
"""
"""True if the SocketIO is open for seeking."""
if self.closed:
raise ValueError("I/O operation on closed socket.")
return super().seekable()

def fileno(self):
"""Return the file descriptor of the underlying socket.
"""
"""Return the file descriptor of the underlying socket."""
self._checkClosed()
return self._sock.fileno()

Expand All @@ -131,4 +128,3 @@ def close(self):
io.RawIOBase.close(self)
self._sock._decref_socketios()
self._sock = None

Loading