25
25
26
26
import socket
27
27
28
+ from functools import wraps
29
+
28
30
from .authentication import MySQL41AuthPlugin
29
31
from .errors import InterfaceError , OperationalError , ProgrammingError
30
32
from .crud import Schema
@@ -67,6 +69,17 @@ def close(self):
67
69
self ._socket = None
68
70
69
71
72
+ def catch_network_exception (func ):
73
+ @wraps (func )
74
+ def wrapper (self , * args , ** kwargs ):
75
+ try :
76
+ return func (self , * args , ** kwargs )
77
+ except (socket .error , RuntimeError ):
78
+ self .disconnect ()
79
+ raise InterfaceError ("Cannot connect to host." )
80
+ return wrapper
81
+
82
+
70
83
class Connection (object ):
71
84
def __init__ (self , settings ):
72
85
self ._user = settings .get ("user" )
@@ -104,32 +117,39 @@ def _authenticate(self):
104
117
plugin .build_authentication_response (extra_data ))
105
118
self .protocol .read_auth_ok ()
106
119
120
+ @catch_network_exception
107
121
def send_sql (self , sql , * args ):
108
122
self .protocol .send_execute_statement ("sql" , sql , args )
109
123
124
+ @catch_network_exception
110
125
def send_insert (self , statement ):
111
126
self .protocol .send_insert (statement )
112
127
ids = None
113
128
if isinstance (statement , AddStatement ):
114
129
ids = statement ._ids
115
130
return Result (self , ids )
116
131
132
+ @catch_network_exception
117
133
def find (self , statement ):
118
134
self .protocol .send_find (statement )
119
135
return DocResult (self ) if statement ._doc_based else RowResult (self )
120
136
137
+ @catch_network_exception
121
138
def delete (self , statement ):
122
139
self .protocol .send_delete (statement )
123
140
return Result (self )
124
141
142
+ @catch_network_exception
125
143
def update (self , statement ):
126
144
self .protocol .send_update (statement )
127
145
return Result (self )
128
146
147
+ @catch_network_exception
129
148
def execute_nonquery (self , namespace , cmd , raise_on_fail = True , * args ):
130
149
self .protocol .send_execute_statement (namespace , cmd , args )
131
150
return Result (self )
132
151
152
+ @catch_network_exception
133
153
def execute_sql_scalar (self , sql , * args ):
134
154
self .protocol .send_execute_statement ("sql" , sql , args )
135
155
result = RowResult (self )
@@ -138,11 +158,34 @@ def execute_sql_scalar(self, sql, *args):
138
158
raise InterfaceError ("No data found" )
139
159
return result [0 ][0 ]
140
160
161
+ @catch_network_exception
141
162
def get_row_result (self , cmd , * args ):
142
163
self .protocol .send_execute_statement ("xplugin" , cmd , args )
143
164
return RowResult (self )
144
165
166
+ @catch_network_exception
167
+ def read_row (self , result ):
168
+ return self .protocol .read_row (result )
169
+
170
+ @catch_network_exception
171
+ def close_result (self , result ):
172
+ self .protocol .close_result (result )
173
+
174
+ @catch_network_exception
175
+ def get_column_metadata (self , result ):
176
+ return self .protocol .get_column_metadata (result )
177
+
178
+ def is_open (self ):
179
+ return self .stream ._socket is not None
180
+
181
+ def disconnect (self ):
182
+ if not self .is_open ():
183
+ return
184
+ self .stream .close ()
185
+
145
186
def close (self ):
187
+ if not self .is_open ():
188
+ return
146
189
if self ._active_result is not None :
147
190
self ._active_result .fetch_all ()
148
191
self .protocol .send_close ()
@@ -153,6 +196,7 @@ def close(self):
153
196
class XConnection (Connection ):
154
197
def __init__ (self , settings ):
155
198
super (XConnection , self ).__init__ (settings )
199
+ self .dependent_connections = []
156
200
self ._routers = settings .get ("routers" , [])
157
201
158
202
if 'host' in settings and settings ['host' ]:
@@ -205,6 +249,19 @@ def connect(self):
205
249
else :
206
250
raise InterfaceError ("Cannot connect to host: {0}" .format (error ))
207
251
252
+ def bind_connection (self , connection ):
253
+ self .dependent_connections .append (connection )
254
+
255
+ def close (self ):
256
+ while self .dependent_connections :
257
+ self .dependent_connections .pop ().close ()
258
+ super (XConnection , self ).close ()
259
+
260
+ def disconnect (self ):
261
+ while self .dependent_connections :
262
+ self .dependent_connections .pop ().disconnect ()
263
+ super (XConnection , self ).disconnect ()
264
+
208
265
209
266
class NodeConnection (Connection ):
210
267
def __init__ (self , settings ):
@@ -324,6 +381,14 @@ def __init__(self, settings):
324
381
self ._connection = XConnection (self ._settings )
325
382
self ._connection .connect ()
326
383
384
+ def bind_to_default_shard (self ):
385
+ if not self .is_open ():
386
+ raise OperationalError ("XSession is not connected to a farm." )
387
+
388
+ nsess = NodeSession (self ._settings )
389
+ self ._connection .bind_connection (nsess ._connection )
390
+ return nsess
391
+
327
392
328
393
class NodeSession (BaseSession ):
329
394
"""Enables interaction with a X Protocol enabled MySQL Server.
0 commit comments