|
16 | 16 | import traceback
|
17 | 17 | import warnings
|
18 | 18 |
|
19 |
| -from .charset import MBLENGTH, charset_by_name, charset_by_id |
| 19 | +from .charset import charset_by_name, charset_by_id |
20 | 20 | from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
|
21 | 21 | from . import converters
|
22 | 22 | from .cursors import Cursor
|
23 | 23 | from .optionfile import Parser
|
| 24 | +from .protocol import ( |
| 25 | + dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper, |
| 26 | + EOFPacketWrapper, LoadLocalPacketWrapper |
| 27 | +) |
24 | 28 | from .util import byte2int, int2byte
|
25 | 29 | from . import err
|
26 | 30 |
|
@@ -85,42 +89,10 @@ def _makefile(sock, mode):
|
85 | 89 |
|
86 | 90 | sha_new = partial(hashlib.new, 'sha1')
|
87 | 91 |
|
88 |
| -NULL_COLUMN = 251 |
89 |
| -UNSIGNED_CHAR_COLUMN = 251 |
90 |
| -UNSIGNED_SHORT_COLUMN = 252 |
91 |
| -UNSIGNED_INT24_COLUMN = 253 |
92 |
| -UNSIGNED_INT64_COLUMN = 254 |
93 |
| - |
94 | 92 | DEFAULT_CHARSET = 'latin1'
|
95 | 93 |
|
96 | 94 | MAX_PACKET_LEN = 2**24-1
|
97 | 95 |
|
98 |
| - |
99 |
| -def dump_packet(data): # pragma: no cover |
100 |
| - def is_ascii(data): |
101 |
| - if 65 <= byte2int(data) <= 122: |
102 |
| - if isinstance(data, int): |
103 |
| - return chr(data) |
104 |
| - return data |
105 |
| - return '.' |
106 |
| - |
107 |
| - try: |
108 |
| - print("packet length:", len(data)) |
109 |
| - for i in range(1, 6): |
110 |
| - f = sys._getframe(i) |
111 |
| - print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno)) |
112 |
| - print("-" * 66) |
113 |
| - except ValueError: |
114 |
| - pass |
115 |
| - dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)] |
116 |
| - for d in dump_data: |
117 |
| - print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) + |
118 |
| - ' ' * (16 - len(d)) + ' ' * 2 + |
119 |
| - ''.join(map(lambda x: "{}".format(is_ascii(x)), d))) |
120 |
| - print("-" * 66) |
121 |
| - print() |
122 |
| - |
123 |
| - |
124 | 96 | SCRAMBLE_LENGTH = 20
|
125 | 97 |
|
126 | 98 | def _scramble(password, message):
|
@@ -214,297 +186,6 @@ def lenenc_int(i):
|
214 | 186 | else:
|
215 | 187 | raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
|
216 | 188 |
|
217 |
| -class MysqlPacket(object): |
218 |
| - """Representation of a MySQL response packet. |
219 |
| -
|
220 |
| - Provides an interface for reading/parsing the packet results. |
221 |
| - """ |
222 |
| - __slots__ = ('_position', '_data') |
223 |
| - |
224 |
| - def __init__(self, data, encoding): |
225 |
| - self._position = 0 |
226 |
| - self._data = data |
227 |
| - |
228 |
| - def get_all_data(self): |
229 |
| - return self._data |
230 |
| - |
231 |
| - def read(self, size): |
232 |
| - """Read the first 'size' bytes in packet and advance cursor past them.""" |
233 |
| - result = self._data[self._position:(self._position+size)] |
234 |
| - if len(result) != size: |
235 |
| - error = ('Result length not requested length:\n' |
236 |
| - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' |
237 |
| - % (size, len(result), self._position, len(self._data))) |
238 |
| - if DEBUG: |
239 |
| - print(error) |
240 |
| - self.dump() |
241 |
| - raise AssertionError(error) |
242 |
| - self._position += size |
243 |
| - return result |
244 |
| - |
245 |
| - def read_all(self): |
246 |
| - """Read all remaining data in the packet. |
247 |
| -
|
248 |
| - (Subsequent read() will return errors.) |
249 |
| - """ |
250 |
| - result = self._data[self._position:] |
251 |
| - self._position = None # ensure no subsequent read() |
252 |
| - return result |
253 |
| - |
254 |
| - def advance(self, length): |
255 |
| - """Advance the cursor in data buffer 'length' bytes.""" |
256 |
| - new_position = self._position + length |
257 |
| - if new_position < 0 or new_position > len(self._data): |
258 |
| - raise Exception('Invalid advance amount (%s) for cursor. ' |
259 |
| - 'Position=%s' % (length, new_position)) |
260 |
| - self._position = new_position |
261 |
| - |
262 |
| - def rewind(self, position=0): |
263 |
| - """Set the position of the data buffer cursor to 'position'.""" |
264 |
| - if position < 0 or position > len(self._data): |
265 |
| - raise Exception("Invalid position to rewind cursor to: %s." % position) |
266 |
| - self._position = position |
267 |
| - |
268 |
| - def get_bytes(self, position, length=1): |
269 |
| - """Get 'length' bytes starting at 'position'. |
270 |
| -
|
271 |
| - Position is start of payload (first four packet header bytes are not |
272 |
| - included) starting at index '0'. |
273 |
| -
|
274 |
| - No error checking is done. If requesting outside end of buffer |
275 |
| - an empty string (or string shorter than 'length') may be returned! |
276 |
| - """ |
277 |
| - return self._data[position:(position+length)] |
278 |
| - |
279 |
| - if PY2: |
280 |
| - def read_uint8(self): |
281 |
| - result = ord(self._data[self._position]) |
282 |
| - self._position += 1 |
283 |
| - return result |
284 |
| - else: |
285 |
| - def read_uint8(self): |
286 |
| - result = self._data[self._position] |
287 |
| - self._position += 1 |
288 |
| - return result |
289 |
| - |
290 |
| - def read_uint16(self): |
291 |
| - result = struct.unpack_from('<H', self._data, self._position)[0] |
292 |
| - self._position += 2 |
293 |
| - return result |
294 |
| - |
295 |
| - def read_uint24(self): |
296 |
| - low, high = struct.unpack_from('<HB', self._data, self._position) |
297 |
| - self._position += 3 |
298 |
| - return low + (high << 16) |
299 |
| - |
300 |
| - def read_uint32(self): |
301 |
| - result = struct.unpack_from('<I', self._data, self._position)[0] |
302 |
| - self._position += 4 |
303 |
| - return result |
304 |
| - |
305 |
| - def read_uint64(self): |
306 |
| - result = struct.unpack_from('<Q', self._data, self._position)[0] |
307 |
| - self._position += 8 |
308 |
| - return result |
309 |
| - |
310 |
| - def read_string(self): |
311 |
| - end_pos = self._data.find(b'\0', self._position) |
312 |
| - if end_pos < 0: |
313 |
| - return None |
314 |
| - result = self._data[self._position:end_pos] |
315 |
| - self._position = end_pos + 1 |
316 |
| - return result |
317 |
| - |
318 |
| - def read_length_encoded_integer(self): |
319 |
| - """Read a 'Length Coded Binary' number from the data buffer. |
320 |
| -
|
321 |
| - Length coded numbers can be anywhere from 1 to 9 bytes depending |
322 |
| - on the value of the first byte. |
323 |
| - """ |
324 |
| - c = self.read_uint8() |
325 |
| - if c == NULL_COLUMN: |
326 |
| - return None |
327 |
| - if c < UNSIGNED_CHAR_COLUMN: |
328 |
| - return c |
329 |
| - elif c == UNSIGNED_SHORT_COLUMN: |
330 |
| - return self.read_uint16() |
331 |
| - elif c == UNSIGNED_INT24_COLUMN: |
332 |
| - return self.read_uint24() |
333 |
| - elif c == UNSIGNED_INT64_COLUMN: |
334 |
| - return self.read_uint64() |
335 |
| - |
336 |
| - def read_length_coded_string(self): |
337 |
| - """Read a 'Length Coded String' from the data buffer. |
338 |
| -
|
339 |
| - A 'Length Coded String' consists first of a length coded |
340 |
| - (unsigned, positive) integer represented in 1-9 bytes followed by |
341 |
| - that many bytes of binary data. (For example "cat" would be "3cat".) |
342 |
| - """ |
343 |
| - length = self.read_length_encoded_integer() |
344 |
| - if length is None: |
345 |
| - return None |
346 |
| - return self.read(length) |
347 |
| - |
348 |
| - def read_struct(self, fmt): |
349 |
| - s = struct.Struct(fmt) |
350 |
| - result = s.unpack_from(self._data, self._position) |
351 |
| - self._position += s.size |
352 |
| - return result |
353 |
| - |
354 |
| - def is_ok_packet(self): |
355 |
| - # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html |
356 |
| - return self._data[0:1] == b'\0' and len(self._data) >= 7 |
357 |
| - |
358 |
| - def is_eof_packet(self): |
359 |
| - # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet |
360 |
| - # Caution: \xFE may be LengthEncodedInteger. |
361 |
| - # If \xFE is LengthEncodedInteger header, 8bytes followed. |
362 |
| - return self._data[0:1] == b'\xfe' and len(self._data) < 9 |
363 |
| - |
364 |
| - def is_auth_switch_request(self): |
365 |
| - # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest |
366 |
| - return self._data[0:1] == b'\xfe' |
367 |
| - |
368 |
| - def is_resultset_packet(self): |
369 |
| - field_count = ord(self._data[0:1]) |
370 |
| - return 1 <= field_count <= 250 |
371 |
| - |
372 |
| - def is_load_local_packet(self): |
373 |
| - return self._data[0:1] == b'\xfb' |
374 |
| - |
375 |
| - def is_error_packet(self): |
376 |
| - return self._data[0:1] == b'\xff' |
377 |
| - |
378 |
| - def check_error(self): |
379 |
| - if self.is_error_packet(): |
380 |
| - self.rewind() |
381 |
| - self.advance(1) # field_count == error (we already know that) |
382 |
| - errno = self.read_uint16() |
383 |
| - if DEBUG: print("errno =", errno) |
384 |
| - err.raise_mysql_exception(self._data) |
385 |
| - |
386 |
| - def dump(self): |
387 |
| - dump_packet(self._data) |
388 |
| - |
389 |
| - |
390 |
| -class FieldDescriptorPacket(MysqlPacket): |
391 |
| - """A MysqlPacket that represents a specific column's metadata in the result. |
392 |
| -
|
393 |
| - Parsing is automatically done and the results are exported via public |
394 |
| - attributes on the class such as: db, table_name, name, length, type_code. |
395 |
| - """ |
396 |
| - |
397 |
| - def __init__(self, data, encoding): |
398 |
| - MysqlPacket.__init__(self, data, encoding) |
399 |
| - self._parse_field_descriptor(encoding) |
400 |
| - |
401 |
| - def _parse_field_descriptor(self, encoding): |
402 |
| - """Parse the 'Field Descriptor' (Metadata) packet. |
403 |
| -
|
404 |
| - This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). |
405 |
| - """ |
406 |
| - self.catalog = self.read_length_coded_string() |
407 |
| - self.db = self.read_length_coded_string() |
408 |
| - self.table_name = self.read_length_coded_string().decode(encoding) |
409 |
| - self.org_table = self.read_length_coded_string().decode(encoding) |
410 |
| - self.name = self.read_length_coded_string().decode(encoding) |
411 |
| - self.org_name = self.read_length_coded_string().decode(encoding) |
412 |
| - self.charsetnr, self.length, self.type_code, self.flags, self.scale = ( |
413 |
| - self.read_struct('<xHIBHBxx')) |
414 |
| - # 'default' is a length coded binary and is still in the buffer? |
415 |
| - # not used for normal result sets... |
416 |
| - |
417 |
| - def description(self): |
418 |
| - """Provides a 7-item tuple compatible with the Python PEP249 DB Spec.""" |
419 |
| - return ( |
420 |
| - self.name, |
421 |
| - self.type_code, |
422 |
| - None, # TODO: display_length; should this be self.length? |
423 |
| - self.get_column_length(), # 'internal_size' |
424 |
| - self.get_column_length(), # 'precision' # TODO: why!?!? |
425 |
| - self.scale, |
426 |
| - self.flags % 2 == 0) |
427 |
| - |
428 |
| - def get_column_length(self): |
429 |
| - if self.type_code == FIELD_TYPE.VAR_STRING: |
430 |
| - mblen = MBLENGTH.get(self.charsetnr, 1) |
431 |
| - return self.length // mblen |
432 |
| - return self.length |
433 |
| - |
434 |
| - def __str__(self): |
435 |
| - return ('%s %r.%r.%r, type=%s, flags=%x' |
436 |
| - % (self.__class__, self.db, self.table_name, self.name, |
437 |
| - self.type_code, self.flags)) |
438 |
| - |
439 |
| - |
440 |
| -class OKPacketWrapper(object): |
441 |
| - """ |
442 |
| - OK Packet Wrapper. It uses an existing packet object, and wraps |
443 |
| - around it, exposing useful variables while still providing access |
444 |
| - to the original packet objects variables and methods. |
445 |
| - """ |
446 |
| - |
447 |
| - def __init__(self, from_packet): |
448 |
| - if not from_packet.is_ok_packet(): |
449 |
| - raise ValueError('Cannot create ' + str(self.__class__.__name__) + |
450 |
| - ' object from invalid packet type') |
451 |
| - |
452 |
| - self.packet = from_packet |
453 |
| - self.packet.advance(1) |
454 |
| - |
455 |
| - self.affected_rows = self.packet.read_length_encoded_integer() |
456 |
| - self.insert_id = self.packet.read_length_encoded_integer() |
457 |
| - self.server_status, self.warning_count = self.read_struct('<HH') |
458 |
| - self.message = self.packet.read_all() |
459 |
| - self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS |
460 |
| - |
461 |
| - def __getattr__(self, key): |
462 |
| - return getattr(self.packet, key) |
463 |
| - |
464 |
| - |
465 |
| -class EOFPacketWrapper(object): |
466 |
| - """ |
467 |
| - EOF Packet Wrapper. It uses an existing packet object, and wraps |
468 |
| - around it, exposing useful variables while still providing access |
469 |
| - to the original packet objects variables and methods. |
470 |
| - """ |
471 |
| - |
472 |
| - def __init__(self, from_packet): |
473 |
| - if not from_packet.is_eof_packet(): |
474 |
| - raise ValueError( |
475 |
| - "Cannot create '{0}' object from invalid packet type".format( |
476 |
| - self.__class__)) |
477 |
| - |
478 |
| - self.packet = from_packet |
479 |
| - self.warning_count, self.server_status = self.packet.read_struct('<xhh') |
480 |
| - if DEBUG: print("server_status=", self.server_status) |
481 |
| - self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS |
482 |
| - |
483 |
| - def __getattr__(self, key): |
484 |
| - return getattr(self.packet, key) |
485 |
| - |
486 |
| - |
487 |
| -class LoadLocalPacketWrapper(object): |
488 |
| - """ |
489 |
| - Load Local Packet Wrapper. It uses an existing packet object, and wraps |
490 |
| - around it, exposing useful variables while still providing access |
491 |
| - to the original packet objects variables and methods. |
492 |
| - """ |
493 |
| - |
494 |
| - def __init__(self, from_packet): |
495 |
| - if not from_packet.is_load_local_packet(): |
496 |
| - raise ValueError( |
497 |
| - "Cannot create '{0}' object from invalid packet type".format( |
498 |
| - self.__class__)) |
499 |
| - |
500 |
| - self.packet = from_packet |
501 |
| - self.filename = self.packet.get_all_data()[1:] |
502 |
| - if DEBUG: print("filename=", self.filename) |
503 |
| - |
504 |
| - def __getattr__(self, key): |
505 |
| - return getattr(self.packet, key) |
506 |
| - |
507 |
| - |
508 | 189 | class Connection(object):
|
509 | 190 | """
|
510 | 191 | Representation of a socket with a mysql server.
|
|
0 commit comments