Skip to content

Commit 14e4c25

Browse files
authored
Split connections module to protocol (#670)
1 parent 9105a9e commit 14e4c25

File tree

2 files changed

+341
-324
lines changed

2 files changed

+341
-324
lines changed

pymysql/connections.py

Lines changed: 5 additions & 324 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
import traceback
1717
import warnings
1818

19-
from .charset import MBLENGTH, charset_by_name, charset_by_id
19+
from .charset import charset_by_name, charset_by_id
2020
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
2121
from . import converters
2222
from .cursors import Cursor
2323
from .optionfile import Parser
24+
from .protocol import (
25+
dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper,
26+
EOFPacketWrapper, LoadLocalPacketWrapper
27+
)
2428
from .util import byte2int, int2byte
2529
from . import err
2630

@@ -85,42 +89,10 @@ def _makefile(sock, mode):
8589

8690
sha_new = partial(hashlib.new, 'sha1')
8791

88-
NULL_COLUMN = 251
89-
UNSIGNED_CHAR_COLUMN = 251
90-
UNSIGNED_SHORT_COLUMN = 252
91-
UNSIGNED_INT24_COLUMN = 253
92-
UNSIGNED_INT64_COLUMN = 254
93-
9492
DEFAULT_CHARSET = 'latin1'
9593

9694
MAX_PACKET_LEN = 2**24-1
9795

98-
99-
def dump_packet(data): # pragma: no cover
100-
def is_ascii(data):
101-
if 65 <= byte2int(data) <= 122:
102-
if isinstance(data, int):
103-
return chr(data)
104-
return data
105-
return '.'
106-
107-
try:
108-
print("packet length:", len(data))
109-
for i in range(1, 6):
110-
f = sys._getframe(i)
111-
print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
112-
print("-" * 66)
113-
except ValueError:
114-
pass
115-
dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
116-
for d in dump_data:
117-
print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) +
118-
' ' * (16 - len(d)) + ' ' * 2 +
119-
''.join(map(lambda x: "{}".format(is_ascii(x)), d)))
120-
print("-" * 66)
121-
print()
122-
123-
12496
SCRAMBLE_LENGTH = 20
12597

12698
def _scramble(password, message):
@@ -214,297 +186,6 @@ def lenenc_int(i):
214186
else:
215187
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
216188

217-
class MysqlPacket(object):
218-
"""Representation of a MySQL response packet.
219-
220-
Provides an interface for reading/parsing the packet results.
221-
"""
222-
__slots__ = ('_position', '_data')
223-
224-
def __init__(self, data, encoding):
225-
self._position = 0
226-
self._data = data
227-
228-
def get_all_data(self):
229-
return self._data
230-
231-
def read(self, size):
232-
"""Read the first 'size' bytes in packet and advance cursor past them."""
233-
result = self._data[self._position:(self._position+size)]
234-
if len(result) != size:
235-
error = ('Result length not requested length:\n'
236-
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
237-
% (size, len(result), self._position, len(self._data)))
238-
if DEBUG:
239-
print(error)
240-
self.dump()
241-
raise AssertionError(error)
242-
self._position += size
243-
return result
244-
245-
def read_all(self):
246-
"""Read all remaining data in the packet.
247-
248-
(Subsequent read() will return errors.)
249-
"""
250-
result = self._data[self._position:]
251-
self._position = None # ensure no subsequent read()
252-
return result
253-
254-
def advance(self, length):
255-
"""Advance the cursor in data buffer 'length' bytes."""
256-
new_position = self._position + length
257-
if new_position < 0 or new_position > len(self._data):
258-
raise Exception('Invalid advance amount (%s) for cursor. '
259-
'Position=%s' % (length, new_position))
260-
self._position = new_position
261-
262-
def rewind(self, position=0):
263-
"""Set the position of the data buffer cursor to 'position'."""
264-
if position < 0 or position > len(self._data):
265-
raise Exception("Invalid position to rewind cursor to: %s." % position)
266-
self._position = position
267-
268-
def get_bytes(self, position, length=1):
269-
"""Get 'length' bytes starting at 'position'.
270-
271-
Position is start of payload (first four packet header bytes are not
272-
included) starting at index '0'.
273-
274-
No error checking is done. If requesting outside end of buffer
275-
an empty string (or string shorter than 'length') may be returned!
276-
"""
277-
return self._data[position:(position+length)]
278-
279-
if PY2:
280-
def read_uint8(self):
281-
result = ord(self._data[self._position])
282-
self._position += 1
283-
return result
284-
else:
285-
def read_uint8(self):
286-
result = self._data[self._position]
287-
self._position += 1
288-
return result
289-
290-
def read_uint16(self):
291-
result = struct.unpack_from('<H', self._data, self._position)[0]
292-
self._position += 2
293-
return result
294-
295-
def read_uint24(self):
296-
low, high = struct.unpack_from('<HB', self._data, self._position)
297-
self._position += 3
298-
return low + (high << 16)
299-
300-
def read_uint32(self):
301-
result = struct.unpack_from('<I', self._data, self._position)[0]
302-
self._position += 4
303-
return result
304-
305-
def read_uint64(self):
306-
result = struct.unpack_from('<Q', self._data, self._position)[0]
307-
self._position += 8
308-
return result
309-
310-
def read_string(self):
311-
end_pos = self._data.find(b'\0', self._position)
312-
if end_pos < 0:
313-
return None
314-
result = self._data[self._position:end_pos]
315-
self._position = end_pos + 1
316-
return result
317-
318-
def read_length_encoded_integer(self):
319-
"""Read a 'Length Coded Binary' number from the data buffer.
320-
321-
Length coded numbers can be anywhere from 1 to 9 bytes depending
322-
on the value of the first byte.
323-
"""
324-
c = self.read_uint8()
325-
if c == NULL_COLUMN:
326-
return None
327-
if c < UNSIGNED_CHAR_COLUMN:
328-
return c
329-
elif c == UNSIGNED_SHORT_COLUMN:
330-
return self.read_uint16()
331-
elif c == UNSIGNED_INT24_COLUMN:
332-
return self.read_uint24()
333-
elif c == UNSIGNED_INT64_COLUMN:
334-
return self.read_uint64()
335-
336-
def read_length_coded_string(self):
337-
"""Read a 'Length Coded String' from the data buffer.
338-
339-
A 'Length Coded String' consists first of a length coded
340-
(unsigned, positive) integer represented in 1-9 bytes followed by
341-
that many bytes of binary data. (For example "cat" would be "3cat".)
342-
"""
343-
length = self.read_length_encoded_integer()
344-
if length is None:
345-
return None
346-
return self.read(length)
347-
348-
def read_struct(self, fmt):
349-
s = struct.Struct(fmt)
350-
result = s.unpack_from(self._data, self._position)
351-
self._position += s.size
352-
return result
353-
354-
def is_ok_packet(self):
355-
# https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
356-
return self._data[0:1] == b'\0' and len(self._data) >= 7
357-
358-
def is_eof_packet(self):
359-
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
360-
# Caution: \xFE may be LengthEncodedInteger.
361-
# If \xFE is LengthEncodedInteger header, 8bytes followed.
362-
return self._data[0:1] == b'\xfe' and len(self._data) < 9
363-
364-
def is_auth_switch_request(self):
365-
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
366-
return self._data[0:1] == b'\xfe'
367-
368-
def is_resultset_packet(self):
369-
field_count = ord(self._data[0:1])
370-
return 1 <= field_count <= 250
371-
372-
def is_load_local_packet(self):
373-
return self._data[0:1] == b'\xfb'
374-
375-
def is_error_packet(self):
376-
return self._data[0:1] == b'\xff'
377-
378-
def check_error(self):
379-
if self.is_error_packet():
380-
self.rewind()
381-
self.advance(1) # field_count == error (we already know that)
382-
errno = self.read_uint16()
383-
if DEBUG: print("errno =", errno)
384-
err.raise_mysql_exception(self._data)
385-
386-
def dump(self):
387-
dump_packet(self._data)
388-
389-
390-
class FieldDescriptorPacket(MysqlPacket):
391-
"""A MysqlPacket that represents a specific column's metadata in the result.
392-
393-
Parsing is automatically done and the results are exported via public
394-
attributes on the class such as: db, table_name, name, length, type_code.
395-
"""
396-
397-
def __init__(self, data, encoding):
398-
MysqlPacket.__init__(self, data, encoding)
399-
self._parse_field_descriptor(encoding)
400-
401-
def _parse_field_descriptor(self, encoding):
402-
"""Parse the 'Field Descriptor' (Metadata) packet.
403-
404-
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
405-
"""
406-
self.catalog = self.read_length_coded_string()
407-
self.db = self.read_length_coded_string()
408-
self.table_name = self.read_length_coded_string().decode(encoding)
409-
self.org_table = self.read_length_coded_string().decode(encoding)
410-
self.name = self.read_length_coded_string().decode(encoding)
411-
self.org_name = self.read_length_coded_string().decode(encoding)
412-
self.charsetnr, self.length, self.type_code, self.flags, self.scale = (
413-
self.read_struct('<xHIBHBxx'))
414-
# 'default' is a length coded binary and is still in the buffer?
415-
# not used for normal result sets...
416-
417-
def description(self):
418-
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
419-
return (
420-
self.name,
421-
self.type_code,
422-
None, # TODO: display_length; should this be self.length?
423-
self.get_column_length(), # 'internal_size'
424-
self.get_column_length(), # 'precision' # TODO: why!?!?
425-
self.scale,
426-
self.flags % 2 == 0)
427-
428-
def get_column_length(self):
429-
if self.type_code == FIELD_TYPE.VAR_STRING:
430-
mblen = MBLENGTH.get(self.charsetnr, 1)
431-
return self.length // mblen
432-
return self.length
433-
434-
def __str__(self):
435-
return ('%s %r.%r.%r, type=%s, flags=%x'
436-
% (self.__class__, self.db, self.table_name, self.name,
437-
self.type_code, self.flags))
438-
439-
440-
class OKPacketWrapper(object):
441-
"""
442-
OK Packet Wrapper. It uses an existing packet object, and wraps
443-
around it, exposing useful variables while still providing access
444-
to the original packet objects variables and methods.
445-
"""
446-
447-
def __init__(self, from_packet):
448-
if not from_packet.is_ok_packet():
449-
raise ValueError('Cannot create ' + str(self.__class__.__name__) +
450-
' object from invalid packet type')
451-
452-
self.packet = from_packet
453-
self.packet.advance(1)
454-
455-
self.affected_rows = self.packet.read_length_encoded_integer()
456-
self.insert_id = self.packet.read_length_encoded_integer()
457-
self.server_status, self.warning_count = self.read_struct('<HH')
458-
self.message = self.packet.read_all()
459-
self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
460-
461-
def __getattr__(self, key):
462-
return getattr(self.packet, key)
463-
464-
465-
class EOFPacketWrapper(object):
466-
"""
467-
EOF Packet Wrapper. It uses an existing packet object, and wraps
468-
around it, exposing useful variables while still providing access
469-
to the original packet objects variables and methods.
470-
"""
471-
472-
def __init__(self, from_packet):
473-
if not from_packet.is_eof_packet():
474-
raise ValueError(
475-
"Cannot create '{0}' object from invalid packet type".format(
476-
self.__class__))
477-
478-
self.packet = from_packet
479-
self.warning_count, self.server_status = self.packet.read_struct('<xhh')
480-
if DEBUG: print("server_status=", self.server_status)
481-
self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
482-
483-
def __getattr__(self, key):
484-
return getattr(self.packet, key)
485-
486-
487-
class LoadLocalPacketWrapper(object):
488-
"""
489-
Load Local Packet Wrapper. It uses an existing packet object, and wraps
490-
around it, exposing useful variables while still providing access
491-
to the original packet objects variables and methods.
492-
"""
493-
494-
def __init__(self, from_packet):
495-
if not from_packet.is_load_local_packet():
496-
raise ValueError(
497-
"Cannot create '{0}' object from invalid packet type".format(
498-
self.__class__))
499-
500-
self.packet = from_packet
501-
self.filename = self.packet.get_all_data()[1:]
502-
if DEBUG: print("filename=", self.filename)
503-
504-
def __getattr__(self, key):
505-
return getattr(self.packet, key)
506-
507-
508189
class Connection(object):
509190
"""
510191
Representation of a socket with a mysql server.

0 commit comments

Comments
 (0)