Skip to content

Commit ab27bcb

Browse files
committed
BUG#23342572: Allow dictionaries as parameters in prepared statements
This patch enables the usage of dictionaries as parameters in prepared statements, using the '%(param)s' format as placeholders. Thank you for the contribution. Change-Id: Ia401ede4fd0cccc7091db3ae7e31a0d33ed9b992
1 parent dc325d1 commit ab27bcb

File tree

5 files changed

+182
-24
lines changed

5 files changed

+182
-24
lines changed

CHANGES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ v8.0.32
2828
- BUG#28020811: Fix multiple reference leaks in the C extension
2929
- BUG#27426532: Reduce callproc roundtrip time
3030
- BUG#24364556: Improve warning behavior
31+
- BUG#23342572: Allow dictionaries as parameters in prepared statements
3132

3233
v8.0.31
3334
=======

lib/mysql/connector/cursor.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
)
102102
RE_SQL_SPLIT_STMTS = re.compile(b""";(?=(?:[^"'`]*["'`].*["'`])*[^"'`]*$)""")
103103
RE_SQL_FIND_PARAM = re.compile(b"""%s(?=(?:[^"'`]*["'`][^"'`]*["'`])*[^"'`]*$)""")
104+
RE_SQL_PYTHON_REPLACE_PARAM = re.compile(r"%\(.*?\)s")
105+
RE_SQL_PYTHON_CAPTURE_PARAM_NAME = re.compile(r"%\((.*?)\)s")
104106

105107
ERR_NO_RESULT_TO_FETCH = "No result set to fetch from"
106108

@@ -1303,7 +1305,7 @@ def _handle_result(self, result: ResultType) -> None:
13031305
def execute(
13041306
self,
13051307
operation: StrOrBytes,
1306-
params: Optional[ParamsSequenceType] = None,
1308+
params: Optional[ParamsSequenceOrDictType] = None,
13071309
multi: bool = False,
13081310
) -> None: # multi is unused
13091311
"""Prepare and execute a MySQL Prepared Statement
@@ -1316,22 +1318,40 @@ def execute(
13161318
13171319
Note: argument "multi" is unused.
13181320
"""
1321+
charset = self._connection.charset
1322+
if charset == "utf8mb4":
1323+
charset = "utf8"
1324+
1325+
if not isinstance(operation, str):
1326+
try:
1327+
operation = operation.decode(charset)
1328+
except UnicodeDecodeError as err:
1329+
raise ProgrammingError(str(err)) from err
1330+
1331+
if isinstance(params, dict):
1332+
replacement_keys = re.findall(RE_SQL_PYTHON_CAPTURE_PARAM_NAME, operation)
1333+
try:
1334+
# Replace params dict with params tuple in correct order.
1335+
params = tuple(params[key] for key in replacement_keys)
1336+
except KeyError as err:
1337+
raise ProgrammingError(
1338+
"Not all placeholders were found in the parameters dict"
1339+
) from err
1340+
# Convert %(name)s to ? before sending it to MySQL
1341+
operation = re.sub(RE_SQL_PYTHON_REPLACE_PARAM, "?", operation)
1342+
13191343
if operation is not self._executed:
13201344
if self._prepared:
13211345
self._connection.cmd_stmt_close(self._prepared["statement_id"])
1322-
13231346
self._executed = operation
1347+
13241348
try:
1325-
if not isinstance(operation, bytes):
1326-
charset = self._connection.charset
1327-
if charset == "utf8mb4":
1328-
charset = "utf8"
1329-
operation = operation.encode(charset)
1330-
except (UnicodeDecodeError, UnicodeEncodeError) as err:
1349+
operation = operation.encode(charset)
1350+
except UnicodeEncodeError as err:
13311351
raise ProgrammingError(str(err)) from err
13321352

1333-
# need to convert %s to ? before sending it to MySQL
13341353
if b"%s" in operation:
1354+
# Convert %s to ? before sending it to MySQL
13351355
operation = re.sub(RE_SQL_FIND_PARAM, b"?", operation)
13361356

13371357
try:

lib/mysql/connector/cursor_cext.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777
RE_SQL_INSERT_STMT,
7878
RE_SQL_INSERT_VALUES,
7979
RE_SQL_ON_DUPLICATE,
80+
RE_SQL_PYTHON_CAPTURE_PARAM_NAME,
81+
RE_SQL_PYTHON_REPLACE_PARAM,
8082
RE_SQL_SPLIT_STMTS,
8183
)
8284
from .errorcode import CR_NO_RESULT_SET
@@ -1095,7 +1097,7 @@ def reset(self, free: bool = True) -> None:
10951097
def execute(
10961098
self,
10971099
operation: StrOrBytes,
1098-
params: Optional[ParamsSequenceType] = None,
1100+
params: Optional[ParamsSequenceOrDictType] = None,
10991101
multi: bool = False,
11001102
) -> None: # multi is unused
11011103
"""Prepare and execute a MySQL Prepared Statement
@@ -1119,23 +1121,40 @@ def execute(
11191121

11201122
self._cnx.handle_unread_result(prepared=True)
11211123

1124+
charset = self._cnx.charset
1125+
if charset == "utf8mb4":
1126+
charset = "utf8"
1127+
1128+
if not isinstance(operation, str):
1129+
try:
1130+
operation = operation.decode(charset)
1131+
except UnicodeDecodeError as err:
1132+
raise ProgrammingError(str(err)) from err
1133+
1134+
if isinstance(params, dict):
1135+
replacement_keys = re.findall(RE_SQL_PYTHON_CAPTURE_PARAM_NAME, operation)
1136+
try:
1137+
# Replace params dict with params tuple in correct order.
1138+
params = tuple(params[key] for key in replacement_keys)
1139+
except KeyError as err:
1140+
raise ProgrammingError(
1141+
"Not all placeholders were found in the parameters dict"
1142+
) from err
1143+
# Convert %(name)s to ? before sending it to MySQL
1144+
operation = re.sub(RE_SQL_PYTHON_REPLACE_PARAM, "?", operation)
1145+
11221146
if operation is not self._executed:
11231147
if self._stmt:
11241148
self._cnx.cmd_stmt_close(self._stmt)
1125-
11261149
self._executed = operation
11271150

11281151
try:
1129-
if not isinstance(operation, bytes):
1130-
charset = self._cnx.charset
1131-
if charset == "utf8mb4":
1132-
charset = "utf8"
1133-
operation = operation.encode(charset)
1134-
except (UnicodeDecodeError, UnicodeEncodeError) as err:
1152+
operation = operation.encode(charset)
1153+
except UnicodeEncodeError as err:
11351154
raise ProgrammingError(str(err)) from err
11361155

1137-
# need to convert %s to ? before sending it to MySQL
11381156
if b"%s" in operation:
1157+
# Convert %s to ? before sending it to MySQL
11391158
operation = re.sub(RE_SQL_FIND_PARAM, b"?", operation)
11401159

11411160
try:

tests/cext/test_cext_cursor.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"""Testing the C Extension cursors
3232
"""
3333

34+
import copy
3435
import datetime
3536
import decimal
3637
import logging
@@ -845,6 +846,69 @@ class CMySQLCursorPreparedTests(tests.CMySQLCursorTests):
845846
"POINT(21.2, 34.2), ?)"
846847
)
847848

849+
insert_dict_stmt = (
850+
"INSERT INTO {0} ("
851+
"my_null, "
852+
"my_bit, "
853+
"my_tinyint, "
854+
"my_smallint, "
855+
"my_mediumint, "
856+
"my_int, "
857+
"my_bigint, "
858+
"my_decimal, "
859+
"my_float, "
860+
"my_double, "
861+
"my_date, "
862+
"my_time, "
863+
"my_datetime, "
864+
"my_year, "
865+
"my_char, "
866+
"my_varchar, "
867+
"my_enum, "
868+
"my_geometry, "
869+
"my_blob) "
870+
"VALUES ("
871+
"%(my_null)s, "
872+
"B'1111100', "
873+
"%(my_tinyint)s, "
874+
"%(my_smallint)s, "
875+
"%(my_mediumint)s, "
876+
"%(my_int)s, "
877+
"%(my_bigint)s, "
878+
"%(my_decimal)s, "
879+
"%(my_float)s, "
880+
"%(my_double)s, "
881+
"%(my_date)s, "
882+
"%(my_time)s, "
883+
"%(my_datetime)s, "
884+
"%(my_year)s, "
885+
"%(my_char)s, "
886+
"%(my_varchar)s, "
887+
"%(my_enum)s, "
888+
"POINT(21.2, 34.2), "
889+
"%(my_blob)s)"
890+
)
891+
892+
insert_columns = (
893+
"my_null",
894+
"my_tinyint",
895+
"my_smallint",
896+
"my_mediumint",
897+
"my_int",
898+
"my_bigint",
899+
"my_decimal",
900+
"my_float",
901+
"my_double",
902+
"my_date",
903+
"my_time",
904+
"my_datetime",
905+
"my_year",
906+
"my_char",
907+
"my_varchar",
908+
"my_enum",
909+
"my_blob",
910+
)
911+
848912
data = (
849913
None,
850914
127,
@@ -939,11 +1003,32 @@ def test_fetchmany(self):
9391003
self.assertEqual(len(rows), 1)
9401004
self.assertEqual(rows[0][1:], self.exp)
9411005

1006+
def test_execute(self):
1007+
# Use dict as placeholders
1008+
data_dict = dict(zip(self.insert_columns, self.data))
1009+
1010+
self.cur.execute(self.insert_dict_stmt.format(self.tbl), data_dict)
1011+
self.cur.execute(f"SELECT * FROM {self.tbl}")
1012+
rows = self.cur.fetchall()
1013+
self.assertEqual(len(rows), 1)
1014+
self.assertEqual(rows[0][1:], self.exp)
1015+
9421016
def test_executemany(self):
9431017
data = [self.data[:], self.data[:]]
9441018
self.cur.executemany(self.insert_stmt.format(self.tbl), data)
945-
self.cur.execute("SELECT * FROM {0}".format(self.tbl))
1019+
self.cur.execute(f"SELECT * FROM {self.tbl}")
9461020
rows = self.cur.fetchall()
9471021
self.assertEqual(len(rows), 2)
9481022
self.assertEqual(rows[0][1:], self.exp)
9491023
self.assertEqual(rows[1][1:], self.exp)
1024+
1025+
# Use dict as placeholders
1026+
data_dict = dict(zip(self.insert_columns, self.data))
1027+
data = [data_dict, copy.deepcopy(data_dict)]
1028+
1029+
self.cur.executemany(self.insert_dict_stmt.format(self.tbl), data)
1030+
self.cur.execute(f"SELECT * FROM {self.tbl}")
1031+
rows = self.cur.fetchall()
1032+
self.assertEqual(len(rows), 4)
1033+
self.assertEqual(rows[0][1:], self.exp)
1034+
self.assertEqual(rows[1][1:], self.exp)

tests/test_cursor.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,29 @@ def test_execute(self):
13091309
self.assertEqual(statement_id, cur._prepared["statement_id"])
13101310
self.assertEqual(exp, cur.fetchone())
13111311

1312+
# Use dict as placeholders
1313+
data = {"value1": 4, "value2": 5}
1314+
exp = (6.4031242374328485,)
1315+
stmt = "SELECT SQRT(POW(%(value1)s, 2) + POW(%(value2)s, 2)) AS hypotenuse"
1316+
cur.execute(stmt, data)
1317+
# See BUG#31964167 about this change in 8.0.22
1318+
statement_id = 4 if tests.MYSQL_VERSION < (8, 0, 24) else 6
1319+
self.assertEqual(statement_id, cur._prepared["statement_id"])
1320+
self.assertEqual(exp, cur.fetchone())
1321+
1322+
# Re-use statement
1323+
data = {"value1": 3, "value2": 4}
1324+
exp = (5.0,)
1325+
cur.execute(stmt, data)
1326+
# See BUG#31964167 about this change in 8.0.22
1327+
statement_id = 5 if tests.MYSQL_VERSION < (8, 0, 24) else 7
1328+
self.assertEqual(statement_id, cur._prepared["statement_id"])
1329+
self.assertEqual(exp, cur.fetchone())
1330+
1331+
# Raise ProgrammingError if placeholder doesn't exist in dict
1332+
stmt = "SELECT SQRT(POW(%(value1)s, 2) + POW(%(unknown)s, 2)) AS hypotenuse"
1333+
self.assertRaises(errors.ProgrammingError, cur.execute, stmt, data)
1334+
13121335
def test_executemany(self):
13131336
cur = self.cnx.cursor(cursor_class=cursor.MySQLCursorPrepared)
13141337

@@ -1327,10 +1350,8 @@ def test_executemany(self):
13271350

13281351
tbl = "myconnpy_cursor"
13291352
self._test_execute_setup(self.cnx, tbl)
1330-
stmt_insert = "INSERT INTO {table} (col1,col2) VALUES (%s, %s)".format(
1331-
table=tbl
1332-
)
1333-
stmt_select = "SELECT col1,col2 FROM {table} ORDER BY col1".format(table=tbl)
1353+
stmt_insert = f"INSERT INTO {tbl} (col1, col2) VALUES (%s, %s)"
1354+
stmt_select = f"SELECT col1, col2 FROM {tbl} ORDER BY col1"
13341355

13351356
cur.executemany(stmt_insert, [(1, 100), (2, 200), (3, 300)])
13361357
self.assertEqual(3, cur.rowcount)
@@ -1346,10 +1367,22 @@ def test_executemany(self):
13461367
)
13471368

13481369
data = [(2,), (3,)]
1349-
stmt = "DELETE FROM {table} WHERE col1 = %s".format(table=tbl)
1370+
stmt = f"DELETE FROM {tbl} WHERE col1 = %s"
13501371
cur.executemany(stmt, data)
13511372
self.assertEqual(2, cur.rowcount)
13521373

1374+
# Use dict as placeholders
1375+
stmt_insert = f"INSERT INTO {tbl} (col1, col2) VALUES (%(col1)s, %(col2)s)"
1376+
cur.executemany(
1377+
stmt_insert,
1378+
[
1379+
{"col1": 3, "col2": 100},
1380+
{"col1": 4, "col2": 200},
1381+
{"col1": 5, "col2": 300},
1382+
],
1383+
)
1384+
self.assertEqual(3, cur.rowcount)
1385+
13531386
self._test_execute_cleanup(self.cnx, tbl)
13541387
cur.close()
13551388

0 commit comments

Comments
 (0)