Skip to content

Commit 1d41cb6

Browse files
committed
DevAPI: Per ReplicaSet SQL execution
XSession can return a NodeSession which shares the connection to the router with the XSession used to create it. It is currently a placeholder method which will be modified in the future. Tests were added for each scenario.
1 parent 21d00b3 commit 1d41cb6

File tree

3 files changed

+126
-3
lines changed

3 files changed

+126
-3
lines changed

lib/mysqlx/connection.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import socket
2727

28+
from functools import wraps
29+
2830
from .authentication import MySQL41AuthPlugin
2931
from .errors import InterfaceError, OperationalError, ProgrammingError
3032
from .crud import Schema
@@ -67,6 +69,17 @@ def close(self):
6769
self._socket = None
6870

6971

72+
def catch_network_exception(func):
73+
@wraps(func)
74+
def wrapper(self, *args, **kwargs):
75+
try:
76+
return func(self, *args, **kwargs)
77+
except (socket.error, RuntimeError):
78+
self.disconnect()
79+
raise InterfaceError("Cannot connect to host.")
80+
return wrapper
81+
82+
7083
class Connection(object):
7184
def __init__(self, settings):
7285
self._user = settings.get("user")
@@ -104,32 +117,39 @@ def _authenticate(self):
104117
plugin.build_authentication_response(extra_data))
105118
self.protocol.read_auth_ok()
106119

120+
@catch_network_exception
107121
def send_sql(self, sql, *args):
108122
self.protocol.send_execute_statement("sql", sql, args)
109123

124+
@catch_network_exception
110125
def send_insert(self, statement):
111126
self.protocol.send_insert(statement)
112127
ids = None
113128
if isinstance(statement, AddStatement):
114129
ids = statement._ids
115130
return Result(self, ids)
116131

132+
@catch_network_exception
117133
def find(self, statement):
118134
self.protocol.send_find(statement)
119135
return DocResult(self) if statement._doc_based else RowResult(self)
120136

137+
@catch_network_exception
121138
def delete(self, statement):
122139
self.protocol.send_delete(statement)
123140
return Result(self)
124141

142+
@catch_network_exception
125143
def update(self, statement):
126144
self.protocol.send_update(statement)
127145
return Result(self)
128146

147+
@catch_network_exception
129148
def execute_nonquery(self, namespace, cmd, raise_on_fail=True, *args):
130149
self.protocol.send_execute_statement(namespace, cmd, args)
131150
return Result(self)
132151

152+
@catch_network_exception
133153
def execute_sql_scalar(self, sql, *args):
134154
self.protocol.send_execute_statement("sql", sql, args)
135155
result = RowResult(self)
@@ -138,11 +158,34 @@ def execute_sql_scalar(self, sql, *args):
138158
raise InterfaceError("No data found")
139159
return result[0][0]
140160

161+
@catch_network_exception
141162
def get_row_result(self, cmd, *args):
142163
self.protocol.send_execute_statement("xplugin", cmd, args)
143164
return RowResult(self)
144165

166+
@catch_network_exception
167+
def read_row(self, result):
168+
return self.protocol.read_row(result)
169+
170+
@catch_network_exception
171+
def close_result(self, result):
172+
self.protocol.close_result(result)
173+
174+
@catch_network_exception
175+
def get_column_metadata(self, result):
176+
return self.protocol.get_column_metadata(result)
177+
178+
def is_open(self):
179+
return self.stream._socket is not None
180+
181+
def disconnect(self):
182+
if not self.is_open():
183+
return
184+
self.stream.close()
185+
145186
def close(self):
187+
if not self.is_open():
188+
return
146189
if self._active_result is not None:
147190
self._active_result.fetch_all()
148191
self.protocol.send_close()
@@ -153,6 +196,7 @@ def close(self):
153196
class XConnection(Connection):
154197
def __init__(self, settings):
155198
super(XConnection, self).__init__(settings)
199+
self.dependent_connections = []
156200
self._routers = settings.get("routers", [])
157201

158202
if 'host' in settings and settings['host']:
@@ -205,6 +249,19 @@ def connect(self):
205249
else:
206250
raise InterfaceError("Cannot connect to host: {0}".format(error))
207251

252+
def bind_connection(self, connection):
253+
self.dependent_connections.append(connection)
254+
255+
def close(self):
256+
while self.dependent_connections:
257+
self.dependent_connections.pop().close()
258+
super(XConnection, self).close()
259+
260+
def disconnect(self):
261+
while self.dependent_connections:
262+
self.dependent_connections.pop().disconnect()
263+
super(XConnection, self).disconnect()
264+
208265

209266
class NodeConnection(Connection):
210267
def __init__(self, settings):
@@ -324,6 +381,14 @@ def __init__(self, settings):
324381
self._connection = XConnection(self._settings)
325382
self._connection.connect()
326383

384+
def bind_to_default_shard(self):
385+
if not self.is_open():
386+
raise OperationalError("XSession is not connected to a farm.")
387+
388+
nsess = NodeSession(self._settings)
389+
self._connection.bind_connection(nsess._connection)
390+
return nsess
391+
327392

328393
class NodeSession(BaseSession):
329394
"""Enables interaction with a X Protocol enabled MySQL Server.

lib/mysqlx/result.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def __init__(self, connection=None, ids=None):
615615
self._ids = ids
616616

617617
if connection is not None:
618-
self._protocol.close_result(self)
618+
self._connection.close_result(self)
619619

620620
def get_affected_items_count(self):
621621
"""Returns the number of affected items for the last operation.
@@ -654,7 +654,7 @@ def __init__(self, connection):
654654
self._init_result()
655655

656656
def _init_result(self):
657-
self._columns = self._protocol.get_column_metadata(self)
657+
self._columns = self._connection.get_column_metadata(self)
658658
self._has_more_data = True if len(self._columns) > 0 else False
659659
self._items = []
660660
self._page_size = 20
@@ -684,7 +684,7 @@ def index_of(self, col_name):
684684
return -1
685685

686686
def _read_item(self, dumping):
687-
row = self._protocol.read_row(self)
687+
row = self._connection.read_row(self)
688688
if row is None:
689689
return None
690690
item = [None] * len(row.field)

tests/test_mysqlx_connection.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,64 @@ def test_close(self):
209209
session.close()
210210
self.assertRaises(mysqlx.OperationalError, schema.exists_in_database)
211211

212+
def test_bind_to_default_shard(self):
213+
try:
214+
# Getting a NodeSession to the default shard
215+
sess = mysqlx.get_session(self.connect_kwargs)
216+
nsess = sess.bind_to_default_shard()
217+
self.assertEqual(sess._settings, nsess._settings)
218+
219+
# Close XSession and all dependent NodeSessions
220+
sess.close()
221+
self.assertFalse(nsess.is_open())
222+
223+
# Connection error on XSession
224+
sess = mysqlx.get_session(self.connect_kwargs)
225+
nsess_a = sess.bind_to_default_shard()
226+
nsess_b = sess.bind_to_default_shard()
227+
tests.MYSQL_SERVERS[0].stop()
228+
tests.MYSQL_SERVERS[0].wait_down()
229+
230+
self.assertRaises(mysqlx.errors.InterfaceError,
231+
sess.get_default_schema().exists_in_database)
232+
self.assertFalse(sess.is_open())
233+
self.assertFalse(nsess_a.is_open())
234+
self.assertFalse(nsess_b.is_open())
235+
236+
tests.MYSQL_SERVERS[0].start()
237+
tests.MYSQL_SERVERS[0].wait_up()
238+
239+
# Connection error on dependent NodeSession
240+
sess = mysqlx.get_session(self.connect_kwargs)
241+
nsess_a = sess.bind_to_default_shard()
242+
nsess_b = sess.bind_to_default_shard()
243+
tests.MYSQL_SERVERS[0].stop()
244+
tests.MYSQL_SERVERS[0].wait_down()
245+
246+
self.assertRaises(mysqlx.errors.InterfaceError,
247+
nsess_a.sql("SELECT 1").execute)
248+
self.assertFalse(nsess_a.is_open())
249+
self.assertTrue(nsess_b.is_open())
250+
self.assertTrue(sess.is_open())
251+
252+
tests.MYSQL_SERVERS[0].start()
253+
tests.MYSQL_SERVERS[0].wait_up()
254+
255+
# Getting a NodeSession a shard (connect error)
256+
sess = mysqlx.get_session(self.connect_kwargs)
257+
tests.MYSQL_SERVERS[0].stop()
258+
tests.MYSQL_SERVERS[0].wait_down()
259+
260+
self.assertRaises(mysqlx.errors.InterfaceError,
261+
sess.bind_to_default_shard)
262+
263+
tests.MYSQL_SERVERS[0].start()
264+
tests.MYSQL_SERVERS[0].wait_up()
265+
266+
finally:
267+
if not tests.MYSQL_SERVERS[0].check_running():
268+
tests.MYSQL_SERVERS[0].start()
269+
tests.MYSQL_SERVERS[0].wait_up()
212270

213271
@unittest.skipIf(tests.MYSQL_VERSION < (5, 7, 12), "XPlugin not compatible")
214272
class MySQLxNodeSessionTests(tests.MySQLxTests):

0 commit comments

Comments
 (0)