Skip to content

_binary prefix is now optional #628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 32 additions & 29 deletions pymysql/connections.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

from .charset import MBLENGTH, charset_by_name, charset_by_id
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
from .converters import escape_item, escape_string, through, conversions as _conv
from . import converters
from .cursors import Cursor
from .optionfile import Parser
from .util import byte2int, int2byte
@@ -44,39 +44,28 @@

_py_version = sys.version_info[:2]

# ``_fast_surrogateescape(s)`` converts a bytes-like object into the native
# string type so that it can be interpolated into a quoted SQL literal.
# On Python 3 every byte >= 0x80 becomes the lone surrogate U+DC00 + byte,
# which is what the 'surrogateescape' error handler produces.
if PY2:
    # The helper used to be left undefined on Python 2 even though callers
    # invoke it unconditionally, which ended in a NameError.
    # NOTE(review): assumes the native ``str`` (bytes) is the right carrier
    # for binary data in Python 2 query strings -- confirm against callers.
    def _fast_surrogateescape(s):
        return bytes(s)
elif _py_version < (3, 6):
    # See http://bugs.python.org/issue24870
    _surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)]

    def _fast_surrogateescape(s):
        # latin1 maps each byte to the code point of the same value; the
        # table then shifts the non-ASCII ones into the surrogate range.
        return s.decode('latin1').translate(_surrogateescape_table)
else:
    def _fast_surrogateescape(s):
        return s.decode('ascii', 'surrogateescape')

# socket.makefile() in Python 2 is not usable because it is very inefficient
# and behaves badly with timeouts.
# XXX: ._socketio doesn't work under IronPython.
if _py_version == (2, 7) and not IRONPYTHON:
if PY2 and not IRONPYTHON:
# The read method of the file-like object returned by sock.makefile() is very slow,
# so we copy the io-based one from Python 3.
from ._socketio import SocketIO

def _makefile(sock, mode):
return io.BufferedReader(SocketIO(sock, mode))
elif _py_version == (2, 6):
# Python 2.6 doesn't have a fast io module,
# so we provide our own implementation.
class SockFile(object):
    """Minimal file-like reader that pulls data straight from a socket."""

    def __init__(self, sock):
        self._sock = sock

    def read(self, n):
        """Read from the socket until *n* bytes are collected.

        Returns what has been collected so far (possibly fewer than *n*
        bytes) as soon as a follow-up ``recv`` call yields no data.
        """
        buf = self._sock.recv(n)
        while len(buf) != n:
            chunk = self._sock.recv(n - len(buf))
            if not chunk:
                break
            buf += chunk
        return buf

def _makefile(sock, mode):
assert mode == 'rb'
return SockFile(sock)
else:
# socket.makefile in Python 3 is nice.
def _makefile(sock, mode):
@@ -570,6 +559,7 @@ class Connection(object):
(if no authenticate method) for returning a string from the user. (experimental)
:param db: Alias for database. (for compatibility to MySQLdb)
:param passwd: Alias for password. (for compatibility to MySQLdb)
:param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
"""

_sock = None
@@ -586,7 +576,7 @@ def __init__(self, host=None, user=None, password="",
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False,
auth_plugin_map={}, read_timeout=None, write_timeout=None,
bind_address=None):
bind_address=None, binary_prefix=False):
if no_delay is not None:
warnings.warn("no_delay option is deprecated", DeprecationWarning)

@@ -693,14 +683,16 @@ def _config(key, arg):
self.autocommit_mode = autocommit

if conv is None:
conv = _conv
conv = converters.conversions

# Need for MySQLdb compatibility.
self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
self.sql_mode = sql_mode
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
self._auth_plugin_map = auth_plugin_map
self._binary_prefix = binary_prefix
if defer_connect:
self._sock = None
else:
@@ -812,7 +804,12 @@ def escape(self, obj, mapping=None):
"""
if isinstance(obj, str_type):
return "'" + self.escape_string(obj) + "'"
return escape_item(obj, self.charset, mapping=mapping)
if isinstance(obj, (bytes, bytearray)):
ret = self._quote_bytes(obj)
if self._binary_prefix:
ret = "_binary" + ret
return ret
return converters.escape_item(obj, self.charset, mapping=mapping)

def literal(self, obj):
"""Alias for escape()
@@ -825,7 +822,13 @@ def escape_string(self, s):
if (self.server_status &
SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES):
return s.replace("'", "''")
return escape_string(s)
return converters.escape_string(s)

def _quote_bytes(self, s):
    """Return the bytes-like *s* as a quoted SQL literal.

    When the server status has the NO_BACKSLASH_ESCAPES flag set, single
    quotes are doubled and the result is wrapped in quotes here; otherwise
    the work is delegated to ``converters.escape_bytes``.
    """
    no_backslash_escapes = (
        self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES)
    if not no_backslash_escapes:
        return converters.escape_bytes(s)
    doubled = s.replace(b"'", b"''")
    return "'%s'" % (_fast_surrogateescape(doubled),)

def cursor(self, cursor=None):
"""Create a new cursor to execute queries with"""
@@ -1510,7 +1513,7 @@ def _get_descriptions(self):
else:
encoding = None
converter = self.connection.decoders.get(field_type)
if converter is through:
if converter is converters.through:
converter = None
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
self.converters.append((encoding, converter))
13 changes: 10 additions & 3 deletions pymysql/converters.py
Original file line number Diff line number Diff line change
@@ -90,9 +90,14 @@ def escape_string(value, mapping=None):
value = value.replace('"', '\\"')
return value

def escape_bytes(value, mapping=None):
def escape_bytes_prefixed(value, mapping=None):
    """Escape *value* and wrap it in quotes preceded by ``_binary``.

    ``mapping`` is accepted but not used.
    """
    assert isinstance(value, (bytes, bytearray))
    escaped = escape_string(value)
    return b"_binary'%s'" % escaped

def escape_bytes(value, mapping=None):
    """Escape *value* and wrap it in single quotes, without any prefix.

    ``mapping`` is accepted but not used.
    """
    assert isinstance(value, (bytes, bytearray))
    escaped = escape_string(value)
    return b"'%s'" % escaped

else:
escape_string = _escape_unicode

@@ -102,9 +107,12 @@ def escape_bytes(value, mapping=None):
# We can escape special chars and surrogateescape at once.
_escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)]

def escape_bytes(value, mapping=None):
def escape_bytes_prefixed(value, mapping=None):
    """Decode *value* as latin1, escape it, and quote it with a ``_binary`` prefix.

    ``mapping`` is accepted but not used.
    """
    text = value.decode('latin1')
    return "_binary'%s'" % text.translate(_escape_bytes_table)

def escape_bytes(value, mapping=None):
    """Decode *value* as latin1, escape it, and wrap it in single quotes.

    ``mapping`` is accepted but not used.
    """
    text = value.decode('latin1')
    return "'%s'" % text.translate(_escape_bytes_table)


def escape_unicode(value, mapping=None):
    """Return *value* escaped via ``_escape_unicode`` and wrapped in single quotes.

    ``mapping`` is accepted but not used.
    """
    escaped = _escape_unicode(value)
    return u"'%s'" % escaped
@@ -373,7 +381,6 @@ def convert_characters(connection, field, data):
set: escape_sequence,
frozenset: escape_sequence,
dict: escape_dict,
bytearray: escape_bytes,
type(None): escape_None,
datetime.date: escape_date,
datetime.datetime: escape_datetime,
4 changes: 2 additions & 2 deletions pymysql/tests/test_issues.py
Original file line number Diff line number Diff line change
@@ -415,8 +415,8 @@ def test_issue_364(self):
"create table issue364 (value_1 binary(3), value_2 varchar(3)) "
"engine=InnoDB default charset=utf8")

sql = "insert into issue364 (value_1, value_2) values (%s, %s)"
usql = u"insert into issue364 (value_1, value_2) values (%s, %s)"
sql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)"
usql = u"insert into issue364 (value_1, value_2) values (_binary %s, %s)"
values = [pymysql.Binary(b"\x00\xff\x00"), u"\xe4\xf6\xfc"]

# test single insert and select
2 changes: 1 addition & 1 deletion pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ class DatabaseTest(unittest.TestCase):

db_module = None
connect_args = ()
connect_kwargs = dict(use_unicode=True, charset="utf8")
connect_kwargs = dict(use_unicode=True, charset="utf8", binary_prefix=True)
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
rows = 10
debug = False
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
connect_args = ()
connect_kwargs = base.PyMySQLTestCase.databases[0].copy()
connect_kwargs.update(dict(read_default_file='~/.my.cnf',
use_unicode=True,
use_unicode=True, binary_prefix=True,
charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))

create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"