@@ -172,19 +172,20 @@ async def close(self) -> bool:
172
172
self ._connection = None
173
173
return True
174
174
175
- def _process_params_dict (
175
+ async def _process_params_dict (
176
176
self , params : ParamsDictType
177
177
) -> Dict [bytes , Union [bytes , Decimal ]]:
178
178
"""Process query parameters given as dictionary."""
179
179
res : Dict [bytes , Any ] = {}
180
180
try :
181
+ sql_mode = await self ._connection .get_sql_mode ()
181
182
to_mysql = self ._connection .converter .to_mysql
182
183
escape = self ._connection .converter .escape
183
184
quote = self ._connection .converter .quote
184
185
for key , value in params .items ():
185
186
conv = value
186
187
conv = to_mysql (conv )
187
- conv = escape (conv )
188
+ conv = escape (conv , sql_mode )
188
189
if not isinstance (value , Decimal ):
189
190
conv = quote (conv )
190
191
res [key .encode ()] = conv
@@ -194,17 +195,18 @@ def _process_params_dict(
194
195
) from err
195
196
return res
196
197
197
- def _process_params (
198
+ async def _process_params (
198
199
self , params : ParamsSequenceType
199
200
) -> Tuple [Union [bytes , Decimal ], ...]:
200
201
"""Process query parameters."""
201
202
result = params [:]
202
203
try :
204
+ sql_mode = await self ._connection .get_sql_mode ()
203
205
to_mysql = self ._connection .converter .to_mysql
204
206
escape = self ._connection .converter .escape
205
207
quote = self ._connection .converter .quote
206
208
result = [to_mysql (value ) for value in result ]
207
- result = [escape (value ) for value in result ]
209
+ result = [escape (value , sql_mode ) for value in result ]
208
210
result = [
209
211
quote (value ) if not isinstance (params [i ], Decimal ) else value
210
212
for i , value in enumerate (result )
@@ -412,7 +414,7 @@ def _check_executed(self) -> None:
412
414
if self ._executed is None :
413
415
raise InterfaceError (ERR_NO_RESULT_TO_FETCH )
414
416
415
- def _prepare_statement (
417
+ async def _prepare_statement (
416
418
self ,
417
419
operation : StrOrBytes ,
418
420
params : Union [Sequence [Any ], Dict [str , Any ]] = (),
@@ -437,9 +439,11 @@ def _prepare_statement(
437
439
438
440
if params :
439
441
if isinstance (params , dict ):
440
- stmt = _bytestr_format_dict (stmt , self ._process_params_dict (params ))
442
+ stmt = _bytestr_format_dict (
443
+ stmt , await self ._process_params_dict (params )
444
+ )
441
445
elif isinstance (params , (list , tuple )):
442
- psub = _ParamSubstitutor (self ._process_params (params ))
446
+ psub = _ParamSubstitutor (await self ._process_params (params ))
443
447
stmt = RE_PY_PARAM .sub (psub , stmt )
444
448
if psub .remaining != 0 :
445
449
raise ProgrammingError (
@@ -461,7 +465,7 @@ async def _fetch_warnings(self) -> Optional[List[WarningType]]:
461
465
result = await cur .fetchall ()
462
466
return result if result else None # type: ignore[return-value]
463
467
464
- def _batch_insert (
468
+ async def _batch_insert (
465
469
self , operation : str , seq_params : Sequence [ParamsSequenceOrDictType ]
466
470
) -> Optional [bytes ]:
467
471
"""Implements multi row insert"""
@@ -496,9 +500,11 @@ def remove_comments(match: re.Match) -> str:
496
500
for params in seq_params :
497
501
tmp = fmt
498
502
if isinstance (params , dict ):
499
- tmp = _bytestr_format_dict (tmp , self ._process_params_dict (params ))
503
+ tmp = _bytestr_format_dict (
504
+ tmp , await self ._process_params_dict (params )
505
+ )
500
506
else :
501
- psub = _ParamSubstitutor (self ._process_params (params ))
507
+ psub = _ParamSubstitutor (await self ._process_params (params ))
502
508
tmp = RE_PY_PARAM .sub (psub , tmp )
503
509
if psub .remaining != 0 :
504
510
raise ProgrammingError (
@@ -686,7 +692,7 @@ async def execute(
686
692
await self ._connection .handle_unread_result ()
687
693
await self ._reset_result ()
688
694
689
- stmt = self ._prepare_statement (operation , params )
695
+ stmt = await self ._prepare_statement (operation , params )
690
696
self ._executed = stmt
691
697
692
698
try :
@@ -719,7 +725,7 @@ async def executemulti(
719
725
await self ._connection .handle_unread_result ()
720
726
await self ._reset_result ()
721
727
722
- stmt = self ._prepare_statement (operation , params )
728
+ stmt = await self ._prepare_statement (operation , params )
723
729
self ._executed = stmt
724
730
self ._executed_list = []
725
731
@@ -752,7 +758,7 @@ async def executemany(
752
758
if not seq_params :
753
759
self ._rowcount = 0
754
760
return None
755
- stmt = self ._batch_insert (operation , seq_params )
761
+ stmt = await self ._batch_insert (operation , seq_params )
756
762
if stmt is not None :
757
763
self ._executed = stmt
758
764
return await self .execute (stmt )
0 commit comments