Skip to content

Commit ca4317a

Browse files
authored
Fix encoding tuple argument (PyMySQL#155)
Since Connections.encoders is broken by design. Tuple and list is escaped directly in `Connection.literal()`. Removed tuple and list from converters mapping. Fixes PyMySQL#145
1 parent e39df07 commit ca4317a

File tree

4 files changed

+51
-51
lines changed

4 files changed

+51
-51
lines changed

MySQLdb/connections.py

+41-41
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class object, used to create cursors (keyword only)
186186

187187
use_unicode = kwargs2.pop('use_unicode', use_unicode)
188188
sql_mode = kwargs2.pop('sql_mode', '')
189-
binary_prefix = kwargs2.pop('binary_prefix', False)
189+
self._binary_prefix = kwargs2.pop('binary_prefix', False)
190190

191191
client_flag = kwargs.get('client_flag', 0)
192192
client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ])
@@ -208,38 +208,28 @@ class object, used to create cursors (keyword only)
208208

209209
self._server_version = tuple([ numeric_part(n) for n in self.get_server_info().split('.')[:2] ])
210210

211+
self.encoding = 'ascii' # overriden in set_character_set()
211212
db = proxy(self)
212-
def _get_string_literal():
213-
# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
214-
def string_literal(obj, dummy=None):
215-
return db.string_literal(obj)
216-
return string_literal
217-
218-
def _get_unicode_literal():
219-
if PY2:
220-
# unicode_literal is called for only unicode object.
221-
def unicode_literal(u, dummy=None):
222-
return db.string_literal(u.encode(unicode_literal.charset))
223-
else:
224-
# unicode_literal() is called for arbitrary object.
225-
def unicode_literal(u, dummy=None):
226-
return db.string_literal(str(u).encode(unicode_literal.charset))
227-
return unicode_literal
228-
229-
def _get_bytes_literal():
230-
def bytes_literal(obj, dummy=None):
231-
return b'_binary' + db.string_literal(obj)
232-
return bytes_literal
233-
234-
def _get_string_decoder():
235-
def string_decoder(s):
236-
return s.decode(string_decoder.charset)
237-
return string_decoder
238-
239-
string_literal = _get_string_literal()
240-
self.unicode_literal = unicode_literal = _get_unicode_literal()
241-
bytes_literal = _get_bytes_literal()
242-
self.string_decoder = string_decoder = _get_string_decoder()
213+
214+
# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
215+
def string_literal(obj, dummy=None):
216+
return db.string_literal(obj)
217+
218+
if PY2:
219+
# unicode_literal is called for only unicode object.
220+
def unicode_literal(u, dummy=None):
221+
return db.string_literal(u.encode(db.encoding))
222+
else:
223+
# unicode_literal() is called for arbitrary object.
224+
def unicode_literal(u, dummy=None):
225+
return db.string_literal(str(u).encode(db.encoding))
226+
227+
def bytes_literal(obj, dummy=None):
228+
return b'_binary' + db.string_literal(obj)
229+
230+
def string_decoder(s):
231+
return s.decode(db.encoding)
232+
243233
if not charset:
244234
charset = self.character_set_name()
245235
self.set_character_set(charset)
@@ -253,12 +243,7 @@ def string_decoder(s):
253243
self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder))
254244
self.converter[FIELD_TYPE.BLOB].append((None, string_decoder))
255245

256-
if binary_prefix:
257-
self.encoders[bytes] = string_literal if PY2 else bytes_literal
258-
self.encoders[bytearray] = bytes_literal
259-
else:
260-
self.encoders[bytes] = string_literal
261-
246+
self.encoders[bytes] = string_literal
262247
self.encoders[unicode] = unicode_literal
263248
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
264249
if self._transactional:
@@ -305,6 +290,16 @@ def __exit__(self, exc, value, tb):
305290
else:
306291
self.commit()
307292

293+
def _bytes_literal(self, bs):
294+
assert isinstance(bs, (bytes, bytearray))
295+
x = self.string_literal(bs) # x is escaped and quoted bytes
296+
if self._binary_prefix:
297+
return b'_binary' + x
298+
return x
299+
300+
def _tuple_literal(self, t, d):
301+
return "(%s)" % (','.join(map(self.literal, t)))
302+
308303
def literal(self, o):
309304
"""If o is a single object, returns an SQL literal as a string.
310305
If o is a non-string sequence, the items of the sequence are
@@ -313,7 +308,14 @@ def literal(self, o):
313308
Non-standard. For internal use; do not use this in your
314309
applications.
315310
"""
316-
s = self.escape(o, self.encoders)
311+
if isinstance(o, bytearray):
312+
s = self._bytes_literal(o)
313+
elif not PY2 and isinstance(o, bytes):
314+
s = self._bytes_literal(o)
315+
elif isinstance(o, (tuple, list)):
316+
s = self._tuple_literal(o)
317+
else:
318+
s = self.escape(o, self.encoders)
317319
# Python 3(~3.4) doesn't support % operation for bytes object.
318320
# We should decode it before using %.
319321
# Decoding with ascii and surrogateescape allows convert arbitrary
@@ -360,8 +362,6 @@ def set_character_set(self, charset):
360362
raise NotSupportedError("server is too old to set charset")
361363
self.query('SET NAMES %s' % charset)
362364
self.store_result()
363-
self.string_decoder.charset = py_charset
364-
self.unicode_literal.charset = py_charset
365365
self.encoding = py_charset
366366

367367
def set_sql_mode(self, sql_mode):

MySQLdb/converters.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@
2929
Don't modify conversions if you can avoid it. Instead, make copies
3030
(with the copy() method), modify the copies, and then pass them to
3131
MySQL.connect().
32-
3332
"""
3433

35-
from _mysql import string_literal, escape_sequence, escape_dict, escape, NULL
34+
from _mysql import string_literal, escape, NULL
3635
from MySQLdb.constants import FIELD_TYPE, FLAG
3736
from MySQLdb.times import *
3837
from MySQLdb.compat import PY2, long
@@ -53,6 +52,7 @@ def Str2Set(s):
5352
return set([ i for i in s.split(',') if i ])
5453

5554
def Set2Str(s, d):
55+
# Only support ascii string. Not tested.
5656
return string_literal(','.join(s), d)
5757

5858
def Thing2Str(s, d):
@@ -97,16 +97,14 @@ def quote_tuple(t, d):
9797
long: Thing2Str,
9898
float: Float2Str,
9999
NoneType: None2NULL,
100-
tuple: quote_tuple,
101-
list: quote_tuple,
102-
dict: escape_dict,
103100
ArrayType: array2Str,
104101
bool: Bool2Str,
105102
Date: Thing2Literal,
106103
DateTimeType: DateTime2literal,
107104
DateTimeDeltaType: DateTimeDelta2literal,
108105
str: Thing2Literal, # default
109106
set: Set2Str,
107+
110108
FIELD_TYPE.TINY: int,
111109
FIELD_TYPE.SHORT: int,
112110
FIELD_TYPE.LONG: long,

MySQLdb/cursors.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -225,22 +225,22 @@ def execute(self, query, args=None):
225225
# db.literal(obj) always returns str.
226226

227227
if PY2 and isinstance(query, unicode):
228-
query = query.encode(db.unicode_literal.charset)
228+
query = query.encode(db.encoding)
229229

230230
if args is not None:
231231
if isinstance(args, dict):
232232
args = dict((key, db.literal(item)) for key, item in args.items())
233233
else:
234234
args = tuple(map(db.literal, args))
235235
if not PY2 and isinstance(query, (bytes, bytearray)):
236-
query = query.decode(db.unicode_literal.charset)
236+
query = query.decode(db.encoding)
237237
try:
238238
query = query % args
239239
except TypeError as m:
240240
self.errorhandler(self, ProgrammingError, str(m))
241241

242242
if isinstance(query, unicode):
243-
query = query.encode(db.unicode_literal.charset, 'surrogateescape')
243+
query = query.encode(db.encoding, 'surrogateescape')
244244

245245
res = None
246246
try:
@@ -353,15 +353,15 @@ def callproc(self, procname, args=()):
353353
q = "SET @_%s_%d=%s" % (procname, index,
354354
db.literal(arg))
355355
if isinstance(q, unicode):
356-
q = q.encode(db.unicode_literal.charset, 'surrogateescape')
356+
q = q.encode(db.encoding, 'surrogateescape')
357357
self._query(q)
358358
self.nextset()
359359

360360
q = "CALL %s(%s)" % (procname,
361361
','.join(['@_%s_%d' % (procname, i)
362362
for i in range(len(args))]))
363363
if isinstance(q, unicode):
364-
q = q.encode(db.unicode_literal.charset, 'surrogateescape')
364+
q = q.encode(db.encoding, 'surrogateescape')
365365
self._query(q)
366366
self._executed = q
367367
if not self._defer_warnings:

_mysql.c

+2
Original file line numberDiff line numberDiff line change
@@ -2777,12 +2777,14 @@ _mysql_methods[] = {
27772777
_mysql_escape__doc__
27782778
},
27792779
{
2780+
// deprecated.
27802781
"escape_sequence",
27812782
(PyCFunction)_mysql_escape_sequence,
27822783
METH_VARARGS,
27832784
_mysql_escape_sequence__doc__
27842785
},
27852786
{
2787+
// deprecated.
27862788
"escape_dict",
27872789
(PyCFunction)_mysql_escape_dict,
27882790
METH_VARARGS,

0 commit comments

Comments
 (0)