2
2
3
3
This module implements Cursors of various types for MySQLdb. By
4
4
default, MySQLdb uses the Cursor class.
5
-
6
5
"""
7
-
6
+ from __future__ import print_function , absolute_import
7
+ from functools import partial
8
8
import re
9
9
import sys
10
- PY2 = sys .version_info [0 ] == 2
11
10
12
11
from MySQLdb .compat import unicode
12
+ from _mysql_exceptions import (
13
+ Warning , Error , InterfaceError , DataError ,
14
+ DatabaseError , OperationalError , IntegrityError , InternalError ,
15
+ NotSupportedError , ProgrammingError )
13
16
14
- restr = r"""
15
- \s
16
- values
17
- \s*
18
- (
19
- \(
20
- [^()']*
21
- (?:
22
- (?:
23
- (?:\(
24
- # ( - editor highlighting helper
25
- .*
26
- \))
27
- |
28
- '
29
- [^\\']*
30
- (?:\\.[^\\']*)*
31
- '
32
- )
33
- [^()']*
34
- )*
35
- \)
36
- )
37
- """
38
17
39
- insert_values = re .compile (restr , re .S | re .I | re .X )
18
# True when running under a Python 2 interpreter.
PY2 = sys.version_info[0] == 2

# The interpreter's native text (unicode) string type.
text_type = unicode if PY2 else str
23
+
40
24
41
- from _mysql_exceptions import Warning , Error , InterfaceError , DataError , \
42
- DatabaseError , OperationalError , IntegrityError , InternalError , \
43
- NotSupportedError , ProgrammingError
25
#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk inserts.
#: You can use it to load large datasets.
#:
#: Group 1 is the statement up to and including ``VALUES``, group 2 is the
#: parenthesised placeholder list, and group 3 is the optional trailing
#: ``ON DUPLICATE ...`` clause (an empty string when absent).
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?)\Z",
    flags=re.I | re.S,
)
44
33
45
34
46
35
class BaseCursor (object ):
@@ -60,6 +49,12 @@ class BaseCursor(object):
60
49
default number of rows fetchmany() will fetch
61
50
"""
62
51
52
+ #: Max statement size which :meth:`executemany` generates.
53
+ #:
54
+ #: Max size of allowed statement is max_allowed_packet - packet_header_size.
55
+ #: Default value of max_allowed_packet is 1048576.
56
+ max_stmt_length = 64 * 1024
57
+
63
58
from _mysql_exceptions import MySQLError , Warning , Error , InterfaceError , \
64
59
DatabaseError , DataError , OperationalError , IntegrityError , \
65
60
InternalError , ProgrammingError , NotSupportedError
@@ -102,6 +97,32 @@ def __exit__(self, *exc_info):
102
97
del exc_info
103
98
self .close ()
104
99
100
def _ensure_bytes(self, x, encoding=None):
    """Return *x* with any text converted to bytes.

    A text value is encoded with ``encoding``.  A tuple or list is
    rebuilt, as the same type, from its recursively converted items.
    Any other value is returned unchanged.
    """
    if isinstance(x, text_type):
        return x.encode(encoding)
    if isinstance(x, (tuple, list)):
        container = type(x)
        return container(
            self._ensure_bytes(item, encoding=encoding) for item in x)
    return x
106
+
107
def _escape_args(self, args, conn):
    """Escape query parameters with ``conn.escape``.

    :param args: a tuple/list of parameters, a dict of named
        parameters, or a single value.
    :param conn: connection object; supplies ``encoding`` and ``escape``.
    :return: a tuple of escaped values for a tuple/list, a dict mapping
        each key to its escaped value for a dict, otherwise the escaped
        value itself.

    On Python 2, text in the parameters (including dict keys) is first
    converted to bytes using ``conn.encoding``.
    """
    to_bytes = partial(self._ensure_bytes, encoding=conn.encoding)

    if isinstance(args, (tuple, list)):
        if PY2:
            args = tuple([to_bytes(item) for item in args])
        return tuple([conn.escape(item) for item in args])

    if isinstance(args, dict):
        if PY2:
            converted = {}
            for key, val in args.items():
                converted[to_bytes(key)] = to_bytes(val)
            args = converted
        escaped = {}
        for key, val in args.items():
            escaped[key] = conn.escape(val)
        return escaped

    # Neither a sequence nor a mapping: try to escape it anyway and let
    # conn.escape raise if it cannot handle the value.
    if PY2:
        args = to_bytes(args)
    return conn.escape(args)
125
+
105
126
def _check_executed (self ):
106
127
if not self ._executed :
107
128
self .errorhandler (self , ProgrammingError , "execute() first" )
@@ -230,62 +251,70 @@ def execute(self, query, args=None):
230
251
return res
231
252
232
253
def executemany(self, query, args):
    # type: (str, list) -> int
    """Execute a multi-row query.

    :param query: query to execute on server
    :param args: Sequence of sequences or mappings, one entry per
        execution of the query.
    :return: Number of rows affected, if any.  Nothing is executed and
        ``None`` is returned when ``args`` is empty.

    A query matching ``RE_INSERT_VALUES`` (simple INSERT/REPLACE ...
    VALUES) is sent as multi-row statements, which improves performance.
    Any other query is equivalent to looping over args with execute().
    """
    del self.messages[:]

    if not args:
        return

    match = RE_INSERT_VALUES.match(query)
    if match is None:
        # Not a simple bulk insert: run the query once per parameter set.
        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    head, placeholders, tail = match.group(1, 2, 3)
    # Formatting with an empty tuple collapses any escaped "%%" in the
    # part of the statement that receives no parameters.
    q_prefix = head % ()
    q_values = placeholders.rstrip()
    q_postfix = tail or ''
    assert q_values[0] == '(' and q_values[-1] == ')'
    return self._do_execute_many(
        q_prefix, q_values, q_postfix, args,
        self.max_stmt_length,
        self._get_db().encoding)
282
+
283
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
    """Run a bulk INSERT/REPLACE as one or more multi-row statements.

    :param prefix: statement text up to and including ``VALUES``.
    :param values: placeholder group, e.g. ``(%s, %s)``, formatted once
        per entry of ``args``.
    :param postfix: text appended after the value groups.
    :param args: iterable of parameter sets.
    :param max_stmt_length: size threshold; the current statement is
        executed and a new one started when adding another group would
        exceed it.
    :param encoding: encoding used to turn text into bytes.
    :return: sum of the values returned by ``execute()``; also stored in
        ``self.rowcount``.
    """
    conn = self._get_db()
    escape_args = self._escape_args

    if isinstance(prefix, text_type):
        prefix = prefix.encode(encoding)
    if PY2 and isinstance(values, text_type):
        values = values.encode(encoding)
    if isinstance(postfix, text_type):
        postfix = postfix.encode(encoding)

    def render(params):
        # Build the bytes of one "(...)" group for a single parameter set.
        group = values % escape_args(params, conn)
        if isinstance(group, text_type):
            if PY2:
                group = group.encode(encoding)
            else:
                group = group.encode(encoding, 'surrogateescape')
        return group

    stmt = bytearray(prefix)
    remaining = iter(args)
    # NOTE(review): an empty iterator raises StopIteration here;
    # executemany() only filters out falsy ``args``.
    stmt += render(next(remaining))

    total = 0
    for params in remaining:
        group = render(params)
        # The extra 1 accounts for the separating comma.
        if len(stmt) + len(group) + len(postfix) + 1 > max_stmt_length:
            total += self.execute(stmt + postfix)
            stmt = bytearray(prefix)
        else:
            stmt += b','
        stmt += group

    total += self.execute(stmt + postfix)
    self.rowcount = total
    return total
289
318
290
319
def callproc (self , procname , args = ()):
291
320
"""Execute stored procedure procname with args
0 commit comments