Skip to content

gh-89301: Fix regression with bound values in traced SQLite statements #92053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion Lib/test/test_sqlite3/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import contextlib
import sqlite3 as sqlite
import unittest

from test.support.os_helper import TESTFN, unlink

from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks


class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
Expand Down Expand Up @@ -224,6 +228,16 @@ def bad_progress():


class TraceCallbackTests(unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
traced = []
cx.set_trace_callback(lambda stmt: traced.append(stmt))
yield
finally:
self.assertEqual(traced, expected)
cx.set_trace_callback(None)

def test_trace_callback_used(self):
"""
Test that the trace callback is invoked once it is set.
Expand Down Expand Up @@ -289,6 +303,52 @@ def trace(statement):
con2.close()
self.assertEqual(traced_statements, queries)

def test_trace_expanded_sql(self):
expected = [
"create table t(t)",
"BEGIN ",
"insert into t values(0)",
"insert into t values(1)",
"insert into t values(2)",
"COMMIT",
]
with memory_database() as cx, self.check_stmt_trace(cx, expected):
with cx:
cx.execute("create table t(t)")
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

@with_tracebacks(
sqlite.DataError,
regex="Expanded SQL string exceeds the maximum string length"
)
def test_trace_too_much_expanded_sql(self):
# If the expanded string is too large, we'll fall back to the
# unexpanded SQL statement (for SQLite 3.14.0 and newer).
# The resulting string length is limited by the runtime limit
# SQLITE_LIMIT_LENGTH.
template = "select 1 as a where a="
category = sqlite.SQLITE_LIMIT_LENGTH
with memory_database() as cx, cx_limit(cx, category=category) as lim:
ok_param = "a"
bad_param = "a" * lim

unexpanded_query = template + "?"
expected = [unexpanded_query]
if sqlite.sqlite_version_info < (3, 14, 0):
expected = []
with self.check_stmt_trace(cx, expected):
cx.execute(unexpanded_query, (bad_param,))

expanded_query = f"{template}'{ok_param}'"
with self.check_stmt_trace(cx, [expanded_query]):
cx.execute(unexpanded_query, (ok_param,))

@with_tracebacks(ZeroDivisionError, regex="division by zero")
def test_trace_bad_handler(self):
with memory_database() as cx:
cx.set_trace_callback(lambda stmt: 5/0)
cx.execute("select 1")


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix a regression in the :mod:`sqlite3` trace callback where bound parameters
were not expanded in the passed statement string. The regression was introduced
in Python 3.10 by :issue:`40318`. Patch by Erlend E. Aasland.
54 changes: 40 additions & 14 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1332,11 +1332,10 @@ progress_callback(void *ctx)
* to ensure future compatibility.
*/
static int
trace_callback(unsigned int type, void *ctx, void *prepared_statement,
void *statement_string)
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
#else
static void
trace_callback(void *ctx, const char *statement_string)
trace_callback(void *ctx, const char *sql)
#endif
{
#ifdef HAVE_TRACE_V2
Expand All @@ -1347,24 +1346,51 @@ trace_callback(void *ctx, const char *statement_string)

PyGILState_STATE gilstate = PyGILState_Ensure();

PyObject *py_statement = NULL;
PyObject *ret = NULL;
py_statement = PyUnicode_DecodeUTF8(statement_string,
strlen(statement_string), "replace");
assert(ctx != NULL);
pysqlite_state *state = ((callback_context *)ctx)->state;
assert(state != NULL);

PyObject *py_statement = NULL;
#ifdef HAVE_TRACE_V2
const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
if (expanded_sql == NULL) {
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
(void)PyErr_NoMemory();
goto exit;
}

PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string length");
print_or_clear_traceback((callback_context *)ctx);

// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
}
else {
py_statement = PyUnicode_FromString(expanded_sql);
sqlite3_free((void *)expanded_sql);
}
#else
if (sql == NULL) {
PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string length");
print_or_clear_traceback((callback_context *)ctx);
goto exit;
}
py_statement = PyUnicode_FromString(sql);
#endif
if (py_statement) {
PyObject *callable = ((callback_context *)ctx)->callable;
ret = PyObject_CallOneArg(callable, py_statement);
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}

if (ret) {
Py_DECREF(ret);
}
else {
print_or_clear_traceback(ctx);
if (PyErr_Occurred()) {
print_or_clear_traceback((callback_context *)ctx);
}

exit:
PyGILState_Release(gilstate);
#ifdef HAVE_TRACE_V2
return 0;
Expand Down