Skip to content

Fix and optimize fetching dict rows. #458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/windows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ name: Build windows wheels

on:
push:
branches:
- master
workflow_dispatch:

jobs:
Expand Down
131 changes: 95 additions & 36 deletions MySQLdb/_mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,8 @@ _mysql_field_to_python(
static PyObject *
_mysql_row_to_tuple(
_mysql_ResultObject *self,
MYSQL_ROW row)
MYSQL_ROW row,
PyObject *unused)
{
unsigned int n, i;
unsigned long *length;
Expand All @@ -1221,7 +1222,8 @@ _mysql_row_to_tuple(
static PyObject *
_mysql_row_to_dict(
_mysql_ResultObject *self,
MYSQL_ROW row)
MYSQL_ROW row,
PyObject *cache)
{
unsigned int n, i;
unsigned long *length;
Expand All @@ -1243,40 +1245,42 @@ _mysql_row_to_dict(
Py_DECREF(v);
goto error;
}

PyObject *tmp = PyDict_SetDefault(r, pyname, v);
Py_DECREF(pyname);
if (!tmp) {
int err = PyDict_Contains(r, pyname);
if (err < 0) { // error
Py_DECREF(v);
goto error;
}
if (tmp == v) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a bug. If v for t1.a and t2.a is same, PyDict_SetDefault may return tmp == v.

Py_DECREF(v);
continue;
if (err) { // duplicate
Py_DECREF(pyname);
pyname = PyUnicode_FromFormat("%s.%s", fields[i].table, fields[i].name);
if (pyname == NULL) {
Py_DECREF(v);
goto error;
}
}

pyname = PyUnicode_FromFormat("%s.%s", fields[i].table, fields[i].name);
if (!pyname) {
Py_DECREF(v);
goto error;
err = PyDict_SetItem(r, pyname, v);
if (cache) {
PyTuple_SET_ITEM(cache, i, pyname);
} else {
Py_DECREF(pyname);
}
int err = PyDict_SetItem(r, pyname, v);
Py_DECREF(pyname);
Py_DECREF(v);
if (err) {
goto error;
}
}
return r;
error:
Py_XDECREF(r);
error:
Py_DECREF(r);
return NULL;
}

static PyObject *
_mysql_row_to_dict_old(
_mysql_ResultObject *self,
MYSQL_ROW row)
MYSQL_ROW row,
PyObject *cache)
{
unsigned int n, i;
unsigned long *length;
Expand All @@ -1302,8 +1306,12 @@ _mysql_row_to_dict_old(
pyname = PyUnicode_FromString(fields[i].name);
}
int err = PyDict_SetItem(r, pyname, v);
Py_DECREF(pyname);
Py_DECREF(v);
if (cache) {
PyTuple_SET_ITEM(cache, i, pyname);
} else {
Py_DECREF(pyname);
}
if (err) {
goto error;
}
Expand All @@ -1314,15 +1322,66 @@ _mysql_row_to_dict_old(
return NULL;
}

typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW);
static PyObject *
_mysql_row_to_dict_cached(
_mysql_ResultObject *self,
MYSQL_ROW row,
PyObject *cache)
{
PyObject *r = PyDict_New();
if (!r) {
return NULL;
}

unsigned int n = mysql_num_fields(self->result);
unsigned long *length = mysql_fetch_lengths(self->result);
MYSQL_FIELD *fields = mysql_fetch_fields(self->result);

for (unsigned int i=0; i<n; i++) {
PyObject *c = PyTuple_GET_ITEM(self->converter, i);
PyObject *v = _mysql_field_to_python(c, row[i], length[i], &fields[i], self->encoding);
if (!v) {
goto error;
}

PyObject *pyname = PyTuple_GET_ITEM(cache, i); // borrowed
int err = PyDict_SetItem(r, pyname, v);
Py_DECREF(v);
if (err) {
goto error;
}
}
return r;
error:
Py_XDECREF(r);
return NULL;
}


typedef PyObject *_convertfunc(_mysql_ResultObject *, MYSQL_ROW, PyObject *);
static _convertfunc * const row_converters[] = {
_mysql_row_to_tuple,
_mysql_row_to_dict,
_mysql_row_to_dict_old
};

Py_ssize_t
_mysql__fetch_row(
_mysql_ResultObject *self,
PyObject *r, /* list object */
Py_ssize_t maxrows,
_PYFUNC *convert_row)
int how)
{
_convertfunc *convert_row = row_converters[how];

PyObject *cache = NULL;
if (maxrows > 0 && how > 0) {
cache = PyTuple_New(mysql_num_fields(self->result));
if (!cache) {
return -1;
}
}

Py_ssize_t i;
for (i = 0; i < maxrows; i++) {
MYSQL_ROW row;
Expand All @@ -1335,20 +1394,29 @@ _mysql__fetch_row(
}
if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) {
_mysql_Exception((_mysql_ConnectionObject *)self->conn);
return -1;
goto error;
}
if (!row) {
break;
}
PyObject *v = convert_row(self, row);
if (!v) return -1;
PyObject *v = convert_row(self, row, cache);
if (!v) {
goto error;
}
if (cache) {
convert_row = _mysql_row_to_dict_cached;
}
if (PyList_Append(r, v)) {
Py_DECREF(v);
return -1;
goto error;
}
Py_DECREF(v);
}
Py_XDECREF(cache);
return i;
error:
Py_XDECREF(cache);
return -1;
}

static char _mysql_ResultObject_fetch_row__doc__[] =
Expand All @@ -1366,15 +1434,7 @@ _mysql_ResultObject_fetch_row(
PyObject *args,
PyObject *kwargs)
{
typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW);
static char *kwlist[] = { "maxrows", "how", NULL };
static _PYFUNC *row_converters[] =
{
_mysql_row_to_tuple,
_mysql_row_to_dict,
_mysql_row_to_dict_old
};
_PYFUNC *convert_row;
static char *kwlist[] = {"maxrows", "how", NULL };
int maxrows=1, how=0;
PyObject *r=NULL;

Expand All @@ -1386,7 +1446,6 @@ _mysql_ResultObject_fetch_row(
PyErr_SetString(PyExc_ValueError, "how out of range");
return NULL;
}
convert_row = row_converters[how];
if (!maxrows) {
if (self->use) {
maxrows = INT_MAX;
Expand All @@ -1396,7 +1455,7 @@ _mysql_ResultObject_fetch_row(
}
}
if (!(r = PyList_New(0))) goto error;
Py_ssize_t rowsadded = _mysql__fetch_row(self, r, maxrows, convert_row);
Py_ssize_t rowsadded = _mysql__fetch_row(self, r, maxrows, how);
if (rowsadded == -1) goto error;

/* DB-API allows return rows as list.
Expand Down
39 changes: 39 additions & 0 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,42 @@ def test_pyparam():
assert cursor._executed == b"SELECT 1, 2"
cursor.execute(b"SELECT %(a)s, %(b)s", {b"a": 3, b"b": 4})
assert cursor._executed == b"SELECT 3, 4"


def test_dictcursor():
conn = connect()
cursor = conn.cursor(MySQLdb.cursors.DictCursor)

cursor.execute("CREATE TABLE t1 (a int, b int, c int)")
_tables.append("t1")
cursor.execute("INSERT INTO t1 (a,b,c) VALUES (1,1,47), (2,2,47)")

cursor.execute("CREATE TABLE t2 (b int, c int)")
_tables.append("t2")
cursor.execute("INSERT INTO t2 (b,c) VALUES (1,1), (2,2)")

cursor.execute("SELECT * FROM t1 JOIN t2 ON t1.b=t2.b")
rows = cursor.fetchall()

assert len(rows) == 2
assert rows[0] == {"a": 1, "b": 1, "c": 47, "t2.b": 1, "t2.c": 1}
assert rows[1] == {"a": 2, "b": 2, "c": 47, "t2.b": 2, "t2.c": 2}

names1 = sorted(rows[0])
names2 = sorted(rows[1])
for a, b in zip(names1, names2):
assert a is b

# Old fetchtype
cursor._fetch_type = 2
cursor.execute("SELECT * FROM t1 JOIN t2 ON t1.b=t2.b")
rows = cursor.fetchall()

assert len(rows) == 2
assert rows[0] == {"t1.a": 1, "t1.b": 1, "t1.c": 47, "t2.b": 1, "t2.c": 1}
assert rows[1] == {"t1.a": 2, "t1.b": 2, "t1.c": 47, "t2.b": 2, "t2.c": 2}

names1 = sorted(rows[0])
names2 = sorted(rows[1])
for a, b in zip(names1, names2):
assert a is b