Skip to content

Commit 49e401b

Browse files
committed
Port executemany() implementation from PyMySQL
1 parent e76b691 commit 49e401b

File tree

3 files changed

+184
-77
lines changed

3 files changed

+184
-77
lines changed

MySQLdb/connections.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def cursor(self, cursorclass=None):
278278
return (cursorclass or self.cursorclass)(self)
279279

280280
def query(self, query):
281+
# Since _mysql releases GIL while querying, we need immutable buffer.
282+
if isinstance(query, bytearray):
283+
query = bytes(query)
281284
if self.waiter is not None:
282285
self.send_query(query)
283286
self.waiter(self.fileno())
@@ -353,6 +356,7 @@ def set_character_set(self, charset):
353356
self.store_result()
354357
self.string_decoder.charset = py_charset
355358
self.unicode_literal.charset = py_charset
359+
self.encoding = py_charset
356360

357361
def set_sql_mode(self, sql_mode):
358362
"""Set the connection sql_mode. See MySQL documentation for

MySQLdb/cursors.py

Lines changed: 106 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,34 @@
22
33
This module implements Cursors of various types for MySQLdb. By
44
default, MySQLdb uses the Cursor class.
5-
65
"""
7-
6+
from __future__ import print_function, absolute_import
7+
from functools import partial
88
import re
99
import sys
10-
PY2 = sys.version_info[0] == 2
1110

1211
from MySQLdb.compat import unicode
12+
from _mysql_exceptions import (
13+
Warning, Error, InterfaceError, DataError,
14+
DatabaseError, OperationalError, IntegrityError, InternalError,
15+
NotSupportedError, ProgrammingError)
1316

14-
restr = r"""
15-
\s
16-
values
17-
\s*
18-
(
19-
\(
20-
[^()']*
21-
(?:
22-
(?:
23-
(?:\(
24-
# ( - editor highlighting helper
25-
.*
26-
\))
27-
|
28-
'
29-
[^\\']*
30-
(?:\\.[^\\']*)*
31-
'
32-
)
33-
[^()']*
34-
)*
35-
\)
36-
)
37-
"""
3817

39-
insert_values = re.compile(restr, re.S | re.I | re.X)
18+
PY2 = sys.version_info[0] == 2
19+
if PY2:
20+
text_type = unicode
21+
else:
22+
text_type = str
23+
4024

41-
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
42-
DatabaseError, OperationalError, IntegrityError, InternalError, \
43-
NotSupportedError, ProgrammingError
25+
#: Regular expression for :meth:`Cursor.executemany`.
26+
#: executemany only suports simple bulk insert.
27+
#: You can use it to load large dataset.
28+
RE_INSERT_VALUES = re.compile(
29+
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
30+
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
31+
r"(\s*(?:ON DUPLICATE.*)?)\Z",
32+
re.IGNORECASE | re.DOTALL)
4433

4534

4635
class BaseCursor(object):
@@ -60,6 +49,12 @@ class BaseCursor(object):
6049
default number of rows fetchmany() will fetch
6150
"""
6251

52+
#: Max stetement size which :meth:`executemany` generates.
53+
#:
54+
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
55+
#: Default value of max_allowed_packet is 1048576.
56+
max_stmt_length = 64*1024
57+
6358
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
6459
DatabaseError, DataError, OperationalError, IntegrityError, \
6560
InternalError, ProgrammingError, NotSupportedError
@@ -102,6 +97,32 @@ def __exit__(self, *exc_info):
10297
del exc_info
10398
self.close()
10499

100+
def _ensure_bytes(self, x, encoding=None):
101+
if isinstance(x, text_type):
102+
x = x.encode(encoding)
103+
elif isinstance(x, (tuple, list)):
104+
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
105+
return x
106+
107+
def _escape_args(self, args, conn):
108+
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
109+
110+
if isinstance(args, (tuple, list)):
111+
if PY2:
112+
args = tuple(map(ensure_bytes, args))
113+
return tuple(conn.escape(arg) for arg in args)
114+
elif isinstance(args, dict):
115+
if PY2:
116+
args = dict((ensure_bytes(key), ensure_bytes(val)) for
117+
(key, val) in args.items())
118+
return dict((key, conn.escape(val)) for (key, val) in args.items())
119+
else:
120+
# If it's not a dictionary let's try escaping it anyways.
121+
# Worst case it will throw a Value error
122+
if PY2:
123+
args = ensure_bytes(args)
124+
return conn.escape(args)
125+
105126
def _check_executed(self):
106127
if not self._executed:
107128
self.errorhandler(self, ProgrammingError, "execute() first")
@@ -230,62 +251,70 @@ def execute(self, query, args=None):
230251
return res
231252

232253
def executemany(self, query, args):
254+
# type: (str, list) -> int
233255
"""Execute a multi-row query.
234256
235-
query -- string, query to execute on server
236-
237-
args
238-
239-
Sequence of sequences or mappings, parameters to use with
240-
query.
241-
242-
Returns long integer rows affected, if any.
257+
:param query: query to execute on server
258+
:param args: Sequence of sequences or mappings. It is used as parameter.
259+
:return: Number of rows affected, if any.
243260
244261
This method improves performance on multiple-row INSERT and
245262
REPLACE. Otherwise it is equivalent to looping over args with
246263
execute().
247264
"""
248265
del self.messages[:]
249-
db = self._get_db()
250-
if not args: return
251-
if PY2 and isinstance(query, unicode):
252-
query = query.encode(db.unicode_literal.charset)
253-
elif not PY2 and isinstance(query, bytes):
254-
query = query.decode(db.unicode_literal.charset)
255-
m = insert_values.search(query)
256-
if not m:
257-
r = 0
258-
for a in args:
259-
r = r + self.execute(query, a)
260-
return r
261-
p = m.start(1)
262-
e = m.end(1)
263-
qv = m.group(1)
264-
try:
265-
q = []
266-
for a in args:
267-
if isinstance(a, dict):
268-
q.append(qv % dict((key, db.literal(item))
269-
for key, item in a.items()))
266+
267+
if not args:
268+
return
269+
270+
m = RE_INSERT_VALUES.match(query)
271+
if m:
272+
q_prefix = m.group(1) % ()
273+
q_values = m.group(2).rstrip()
274+
q_postfix = m.group(3) or ''
275+
assert q_values[0] == '(' and q_values[-1] == ')'
276+
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
277+
self.max_stmt_length,
278+
self._get_db().encoding)
279+
280+
self.rowcount = sum(self.execute(query, arg) for arg in args)
281+
return self.rowcount
282+
283+
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
284+
conn = self._get_db()
285+
escape = self._escape_args
286+
if isinstance(prefix, text_type):
287+
prefix = prefix.encode(encoding)
288+
if PY2 and isinstance(values, text_type):
289+
values = values.encode(encoding)
290+
if isinstance(postfix, text_type):
291+
postfix = postfix.encode(encoding)
292+
sql = bytearray(prefix)
293+
args = iter(args)
294+
v = values % escape(next(args), conn)
295+
if isinstance(v, text_type):
296+
if PY2:
297+
v = v.encode(encoding)
298+
else:
299+
v = v.encode(encoding, 'surrogateescape')
300+
sql += v
301+
rows = 0
302+
for arg in args:
303+
v = values % escape(arg, conn)
304+
if isinstance(v, text_type):
305+
if PY2:
306+
v = v.encode(encoding)
270307
else:
271-
q.append(qv % tuple([db.literal(item) for item in a]))
272-
except TypeError as msg:
273-
if msg.args[0] in ("not enough arguments for format string",
274-
"not all arguments converted"):
275-
self.errorhandler(self, ProgrammingError, msg.args[0])
308+
v = v.encode(encoding, 'surrogateescape')
309+
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
310+
rows += self.execute(sql + postfix)
311+
sql = bytearray(prefix)
276312
else:
277-
self.errorhandler(self, TypeError, msg)
278-
except (SystemExit, KeyboardInterrupt):
279-
raise
280-
except:
281-
exc, value = sys.exc_info()[:2]
282-
self.errorhandler(self, exc, value)
283-
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
284-
if not PY2:
285-
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
286-
r = self._query(qs)
287-
if not self._defer_warnings: self._warning_check()
288-
return r
313+
sql += b','
314+
sql += v
315+
rows += self.execute(sql + postfix)
316+
self.rowcount = rows
317+
return rows
289318

290319
def callproc(self, procname, args=()):
291320
"""Execute stored procedure procname with args

tests/test_cursor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import py.test
2+
import MySQLdb.cursors
3+
from configdb import connection_factory
4+
5+
6+
_conns = []
7+
_tables = []
8+
9+
def connect(**kwargs):
10+
conn = connection_factory(**kwargs)
11+
_conns.append(conn)
12+
return conn
13+
14+
15+
def teardown_function(function):
16+
if _tables:
17+
c = _conns[0]
18+
cur = c.cursor()
19+
for t in _tables:
20+
cur.execute("DROP TABLE %s" % (t,))
21+
cur.close()
22+
del _tables[:]
23+
24+
for c in _conns:
25+
c.close()
26+
del _conns[:]
27+
28+
29+
def test_executemany():
30+
conn = connect()
31+
cursor = conn.cursor()
32+
33+
cursor.execute("create table test (data varchar(10))")
34+
_tables.append("test")
35+
36+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)")
37+
assert m is not None, 'error parse %s'
38+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
39+
40+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)")
41+
assert m is not None, 'error parse %(name)s'
42+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
43+
44+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)")
45+
assert m is not None, 'error parse %(id_name)s'
46+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
47+
48+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update")
49+
assert m is not None, 'error parse %(id_name)s'
50+
assert m.group(3) == ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?'
51+
52+
# cursor._executed myst bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
53+
# list args
54+
data = range(10)
55+
cursor.executemany("insert into test (data) values (%s)", data)
56+
assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %s not in one query'
57+
58+
# dict args
59+
data_dict = [{'data': i} for i in range(10)]
60+
cursor.executemany("insert into test (data) values (%(data)s)", data_dict)
61+
assert cursor._executed.endswith(",(7),(8),(9)"), 'execute many with %(data)s not in one query'
62+
63+
# %% in column set
64+
cursor.execute("""\
65+
CREATE TABLE percent_test (
66+
`A%` INTEGER,
67+
`B%` INTEGER)""")
68+
try:
69+
q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
70+
assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None
71+
cursor.executemany(q, [(3, 4), (5, 6)])
72+
assert cursor._executed.endswith("(3, 4),(5, 6)"), "executemany with %% not in one query"
73+
finally:
74+
cursor.execute("DROP TABLE IF EXISTS percent_test")

0 commit comments

Comments
 (0)