Skip to content

Commit b3b2826

Browse files
committed
added DictCursor support. fixed a previously unknown bug in executemany(). fixed bug in python3 version of cursor.description
1 parent cfce952 commit b3b2826

File tree

4 files changed

+104
-11
lines changed

4 files changed

+104
-11
lines changed

pymysql/connections.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ class MysqlPacket(object):
189189
from the network socket, removes packet header and provides an interface
190190
for reading/parsing the packet results."""
191191

192-
def __init__(self, socket):
192+
def __init__(self, socket, connection):
193+
self.connection = connection
193194
self.__position = 0
194195
self.__recv_packet(socket)
195196
del socket
@@ -354,7 +355,7 @@ def __parse_field_descriptor(self):
354355
self.db = self.read_length_coded_string()
355356
self.table_name = self.read_length_coded_string()
356357
self.org_table = self.read_length_coded_string()
357-
self.name = self.read_length_coded_string()
358+
self.name = self.read_length_coded_string().decode(self.connection.charset)
358359
self.org_name = self.read_length_coded_string()
359360
self.advance(1) # non-null filler
360361
self.charsetnr = struct.unpack('<H', self.read(2))[0]
@@ -578,8 +579,10 @@ def literal(self, obj):
578579
''' Alias for escape() '''
579580
return escape_item(obj, self.charset)
580581

581-
def cursor(self):
582+
def cursor(self, cursor=None):
582583
''' Create a new cursor to execute queries with '''
584+
if cursor:
585+
return cursor(self)
583586
return self.cursorclass(self)
584587

585588
def __enter__(self):
@@ -676,7 +679,7 @@ def read_packet(self, packet_type=MysqlPacket):
676679
# socket.recv(large_number)? if so, maybe we should buffer
677680
# the socket.recv() (though that obviously makes memory management
678681
# more complicated.
679-
packet = packet_type(self.socket)
682+
packet = packet_type(self.socket, self)
680683
packet.check_error()
681684
return packet
682685

@@ -762,7 +765,7 @@ def _send_authentication(self):
762765

763766
sock.send(data)
764767

765-
auth_packet = MysqlPacket(sock)
768+
auth_packet = MysqlPacket(sock, self)
766769
auth_packet.check_error()
767770
if DEBUG: auth_packet.dump()
768771

@@ -777,7 +780,7 @@ def _send_authentication(self):
777780
data = pack_int24(len(data)) + int2byte(next_packet) + data
778781

779782
sock.send(data)
780-
auth_packet = MysqlPacket(sock)
783+
auth_packet = MysqlPacket(sock, self)
781784
auth_packet.check_error()
782785
if DEBUG: auth_packet.dump()
783786

@@ -798,7 +801,7 @@ def get_proto_info(self):
798801
def _get_server_information(self):
799802
sock = self.socket
800803
i = 0
801-
packet = MysqlPacket(sock)
804+
packet = MysqlPacket(sock, self)
802805
data = packet.get_all_data()
803806

804807
if DEBUG: dump_packet(data)

pymysql/cursors.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ def execute(self, query, args=None):
113113
def executemany(self, query, args):
114114
''' Run several data against one query '''
115115
del self.messages[:]
116-
conn = self._get_db()
116+
#conn = self._get_db()
117117
if not args:
118118
return
119-
charset = conn.charset
120-
if isinstance(query, unicode):
121-
query = query.encode(charset)
119+
#charset = conn.charset
120+
#if isinstance(query, unicode):
121+
# query = query.encode(charset)
122122

123123
self.rowcount = sum([ self.execute(query, arg) for arg in args ])
124124
return self.rowcount
@@ -248,3 +248,44 @@ def __iter__(self):
248248
InternalError = InternalError
249249
ProgrammingError = ProgrammingError
250250
NotSupportedError = NotSupportedError
251+
252+
class DictCursor(Cursor):
253+
"""A cursor which returns results as a dictionary"""
254+
255+
def execute(self, query, args=None):
256+
result = super(DictCursor, self).execute(query, args)
257+
if self.description:
258+
self._fields = [ field[0] for field in self.description ]
259+
return result
260+
261+
def fetchone(self):
262+
''' Fetch the next row '''
263+
self._check_executed()
264+
if self._rows is None or self.rownumber >= len(self._rows):
265+
return None
266+
result = dict(zip(self._fields, self._rows[self.rownumber]))
267+
self.rownumber += 1
268+
return result
269+
270+
def fetchmany(self, size=None):
271+
''' Fetch several rows '''
272+
self._check_executed()
273+
if self._rows is None:
274+
return None
275+
end = self.rownumber + (size or self.arraysize)
276+
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
277+
self.rownumber = min(end, len(self._rows))
278+
return tuple(result)
279+
280+
def fetchall(self):
281+
''' Fetch all the rows '''
282+
self._check_executed()
283+
if self._rows is None:
284+
return None
285+
if self.rownumber:
286+
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
287+
else:
288+
result = [ dict(zip(self._fields, r)) for r in self._rows ]
289+
self.rownumber = len(self._rows)
290+
return tuple(result)
291+

pymysql/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pymysql.tests.test_issues import *
22
from pymysql.tests.test_example import *
33
from pymysql.tests.test_basic import *
4+
from pymysql.tests.test_DictCursor import *
45

56
if __name__ == "__main__":
67
import unittest

pymysql/tests/test_DictCursor.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from pymysql.tests import base
2+
import pymysql.cursors
3+
4+
import datetime
5+
6+
class TestDictCursor(base.PyMySQLTestCase):
7+
8+
def test_DictCursor(self):
9+
#all assert test compare to the structure as would come out from MySQLdb
10+
conn = self.connections[0]
11+
c = conn.cursor(pymysql.cursors.DictCursor)
12+
# create a table ane some data to query
13+
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
14+
data = (("bob",21,"1990-02-06 23:04:56"),
15+
("jim",56,"1955-05-09 13:12:45"),
16+
("fred",100,"1911-09-12 01:01:01"))
17+
bob = {'name':'bob','age':21,'DOB':datetime.datetime(1990, 02, 6, 23, 04, 56)}
18+
jim = {'name':'jim','age':56,'DOB':datetime.datetime(1955, 05, 9, 13, 12, 45)}
19+
fred = {'name':'fred','age':100,'DOB':datetime.datetime(1911, 9, 12, 1, 1, 1)}
20+
try:
21+
c.executemany("insert into dictcursor values (%s,%s,%s)", data)
22+
# try an update which should return no rows
23+
c.execute("update dictcursor set age=20 where name='bob'")
24+
bob['age'] = 20
25+
# pull back the single row dict for bob and check
26+
c.execute("SELECT * from dictcursor where name='bob'")
27+
r = c.fetchone()
28+
self.assertEqual(bob,r,"fetchone via DictCursor failed")
29+
# same again, but via fetchall => tuple)
30+
c.execute("SELECT * from dictcursor where name='bob'")
31+
r = c.fetchall()
32+
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
33+
# get all 3 row via fetchall
34+
c.execute("SELECT * from dictcursor")
35+
r = c.fetchall()
36+
self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor")
37+
# get all 2 row via fetchmany
38+
c.execute("SELECT * from dictcursor")
39+
r = c.fetchmany(2)
40+
self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor")
41+
finally:
42+
c.execute("drop table dictcursor")
43+
44+
__all__ = ["TestDictCursor"]
45+
46+
if __name__ == "__main__":
47+
import unittest
48+
unittest.main()

0 commit comments

Comments
 (0)