|
29 | 29 | """Implementing communication with MySQL servers.
|
30 | 30 | """
|
31 | 31 |
|
| 32 | +from decimal import Decimal |
32 | 33 | from io import IOBase
|
| 34 | +import datetime |
33 | 35 | import logging
|
34 | 36 | import os
|
35 | 37 | import platform
|
36 | 38 | import socket
|
| 39 | +import struct |
37 | 40 | import time
|
38 | 41 | import warnings
|
39 | 42 |
|
40 | 43 | from .authentication import get_auth_plugin
|
41 | 44 | from .constants import (
|
42 |
| - ClientFlag, ServerCmd, ServerFlag, |
| 45 | + ClientFlag, ServerCmd, ServerFlag, FieldType, |
43 | 46 | flag_is_set, ShutdownType, NET_BUFFER_LENGTH
|
44 | 47 | )
|
45 | 48 |
|
|
52 | 55 | MySQLCursorBufferedNamedTuple)
|
53 | 56 | from .network import MySQLUnixSocket, MySQLTCPSocket
|
54 | 57 | from .protocol import MySQLProtocol
|
55 |
| -from .utils import int4store, linux_distribution |
| 58 | +from .utils import int1store, int4store, lc_int, linux_distribution |
56 | 59 | from .abstracts import MySQLConnectionAbstract
|
57 | 60 |
|
58 | 61 | logging.getLogger(__name__).addHandler(logging.NullHandler())
|
@@ -102,6 +105,7 @@ def __init__(self, *args, **kwargs):
|
102 | 105 | self._auth_plugin = None
|
103 | 106 | self._krb_service_principal = None
|
104 | 107 | self._pool_config_version = None
|
| 108 | + self._query_attrs_supported = False |
105 | 109 |
|
106 | 110 | self._columns_desc = []
|
107 | 111 |
|
@@ -179,6 +183,10 @@ def _do_handshake(self):
|
179 | 183 | if handshake['capabilities'] & ClientFlag.PLUGIN_AUTH:
|
180 | 184 | self.set_client_flags([ClientFlag.PLUGIN_AUTH])
|
181 | 185 |
|
| 186 | + if handshake['capabilities'] & ClientFlag.CLIENT_QUERY_ATTRIBUTES: |
| 187 | + self._query_attrs_supported = True |
| 188 | + self.set_client_flags([ClientFlag.CLIENT_QUERY_ATTRIBUTES]) |
| 189 | + |
182 | 190 | self._handshake = handshake
|
183 | 191 |
|
184 | 192 | def _do_auth(self, username=None, password=None, database=None,
|
@@ -755,8 +763,85 @@ def cmd_query(self, query, raw=False, buffered=False, raw_as_string=False):
|
755 | 763 |
|
756 | 764 | Returns a tuple()
|
757 | 765 | """
|
758 |
| - if not isinstance(query, bytes): |
759 |
| - query = query.encode('utf-8') |
| 766 | + if not isinstance(query, bytearray): |
| 767 | + if isinstance(query, str): |
| 768 | + query = query.encode('utf-8') |
| 769 | + query = bytearray(query) |
| 770 | + # Prepare query attrs |
| 771 | + charset = self.charset if self.charset != "utf8mb4" else "utf8" |
| 772 | + packet = bytearray() |
| 773 | + if not self._query_attrs_supported and self._query_attrs: |
| 774 | + warnings.warn( |
| 775 | + "This version of the server does not support Query Attributes", |
| 776 | + category=Warning) |
| 777 | + if self._client_flags & ClientFlag.CLIENT_QUERY_ATTRIBUTES: |
| 778 | + names = [] |
| 779 | + types = [] |
| 780 | + values = [] |
| 781 | + null_bitmap = [0] * ((len(self._query_attrs) + 7) // 8) |
| 782 | + for pos, attr_tuple in enumerate(self._query_attrs): |
| 783 | + value = attr_tuple[1] |
| 784 | + flags = 0 |
| 785 | + if value is None: |
| 786 | + null_bitmap[(pos // 8)] |= 1 << (pos % 8) |
| 787 | + types.append(int1store(FieldType.NULL) + |
| 788 | + int1store(flags)) |
| 789 | + continue |
| 790 | + elif isinstance(value, int): |
| 791 | + (packed, field_type, |
| 792 | + flags) = self._protocol._prepare_binary_integer(value) |
| 793 | + values.append(packed) |
| 794 | + elif isinstance(value, str): |
| 795 | + value = value.encode(charset) |
| 796 | + values.append(lc_int(len(value)) + value) |
| 797 | + field_type = FieldType.VARCHAR |
| 798 | + elif isinstance(value, bytes): |
| 799 | + values.append(lc_int(len(value)) + value) |
| 800 | + field_type = FieldType.BLOB |
| 801 | + elif isinstance(value, Decimal): |
| 802 | + values.append( |
| 803 | + lc_int(len(str(value).encode( |
| 804 | + charset))) + str(value).encode(charset)) |
| 805 | + field_type = FieldType.DECIMAL |
| 806 | + elif isinstance(value, float): |
| 807 | + values.append(struct.pack('<d', value)) |
| 808 | + field_type = FieldType.DOUBLE |
| 809 | + elif isinstance(value, (datetime.datetime, datetime.date)): |
| 810 | + (packed, field_type) = \ |
| 811 | + self._protocol._prepare_binary_timestamp(value) |
| 812 | + values.append(packed) |
| 813 | + elif isinstance(value, (datetime.timedelta, datetime.time)): |
| 814 | + (packed, field_type) = \ |
| 815 | + self._protocol._prepare_binary_time(value) |
| 816 | + values.append(packed) |
| 817 | + else: |
| 818 | + raise errors.ProgrammingError( |
| 819 | + "MySQL binary protocol can not handle " |
| 820 | + "'{classname}' objects".format( |
| 821 | + classname=value.__class__.__name__)) |
| 822 | + types.append(int1store(field_type) + |
| 823 | + int1store(flags)) |
| 824 | + name = attr_tuple[0].encode(charset) |
| 825 | + names.append(lc_int(len(name)) + name) |
| 826 | + |
| 827 | + # int<lenenc> parameter_count Number of parameters |
| 828 | + packet.extend(lc_int(len(self._query_attrs))) |
| 829 | + # int<lenenc> parameter_set_count Number of parameter sets. |
| 830 | + # Currently always 1 |
| 831 | + packet.extend(lc_int(1)) |
| 832 | + if values: |
| 833 | + packet.extend( |
| 834 | + b''.join([struct.pack('B', bit) for bit in null_bitmap]) + |
| 835 | + int1store(1)) |
| 836 | + for _type, name in zip(types, names): |
| 837 | + packet.extend(_type) |
| 838 | + packet.extend(name) |
| 839 | + |
| 840 | + for value in values: |
| 841 | + packet.extend(value) |
| 842 | + |
| 843 | + packet.extend(query) |
| 844 | + query = bytes(packet) |
760 | 845 | try:
|
761 | 846 | result = self._handle_result(self._send_cmd(ServerCmd.QUERY, query))
|
762 | 847 | except errors.ProgrammingError as err:
|
@@ -789,13 +874,23 @@ def cmd_query_iter(self, statements):
|
789 | 874 |
|
790 | 875 | Returns a generator.
|
791 | 876 | """
|
| 877 | + packet = bytearray() |
792 | 878 | if not isinstance(statements, bytearray):
|
793 | 879 | if isinstance(statements, str):
|
794 | 880 | statements = statements.encode('utf8')
|
795 | 881 | statements = bytearray(statements)
|
796 | 882 |
|
| 883 | + if self._client_flags & ClientFlag.CLIENT_QUERY_ATTRIBUTES: |
| 884 | + # int<lenenc> parameter_count Number of parameters |
| 885 | + packet.extend(lc_int(0)) |
| 886 | + # int<lenenc> parameter_set_count Number of parameter sets. |
| 887 | + # Currently always 1 |
| 888 | + packet.extend(lc_int(1)) |
| 889 | + |
| 890 | + packet.extend(statements) |
| 891 | + query = bytes(packet) |
797 | 892 | # Handle the first query result
|
798 |
| - yield self._handle_result(self._send_cmd(ServerCmd.QUERY, statements)) |
| 893 | + yield self._handle_result(self._send_cmd(ServerCmd.QUERY, query)) |
799 | 894 |
|
800 | 895 | # Handle next results, if any
|
801 | 896 | while self._have_next_result:
|
@@ -1266,10 +1361,18 @@ def cmd_stmt_execute(self, statement_id, data=(), parameters=(), flags=0):
|
1266 | 1361 | self.cmd_stmt_send_long_data(statement_id, param_id,
|
1267 | 1362 | data[param_id])
|
1268 | 1363 | long_data_used[param_id] = (binary,)
|
1269 |
| - |
1270 |
| - execute_packet = self._protocol.make_stmt_execute( |
1271 |
| - statement_id, data, tuple(parameters), flags, |
1272 |
| - long_data_used, self.charset) |
| 1364 | + if not self._query_attrs_supported and self._query_attrs: |
| 1365 | + warnings.warn( |
| 1366 | + "This version of the server does not support Query Attributes", |
| 1367 | + category=Warning) |
| 1368 | + if self._client_flags & ClientFlag.CLIENT_QUERY_ATTRIBUTES: |
| 1369 | + execute_packet = self._protocol.make_stmt_execute( |
| 1370 | + statement_id, data, tuple(parameters), flags, |
| 1371 | + long_data_used, self.charset, self._query_attrs) |
| 1372 | + else: |
| 1373 | + execute_packet = self._protocol.make_stmt_execute( |
| 1374 | + statement_id, data, tuple(parameters), flags, |
| 1375 | + long_data_used, self.charset) |
1273 | 1376 | packet = self._send_cmd(ServerCmd.STMT_EXECUTE, packet=execute_packet)
|
1274 | 1377 | result = self._handle_binary_result(packet)
|
1275 | 1378 | return result
|
|
0 commit comments