24
24
from sqlalchemy import orm , event
25
25
from sqlalchemy .orm .exc import UnmappedClassError
26
26
from sqlalchemy .orm .session import Session
27
- from sqlalchemy .orm .interfaces import MapperExtension , SessionExtension , \
28
- EXT_CONTINUE
27
+ from sqlalchemy .event import listen
29
28
from sqlalchemy .interfaces import ConnectionProxy
30
29
from sqlalchemy .engine .url import make_url
31
30
from sqlalchemy .ext .declarative import declarative_base , DeclarativeMeta
@@ -91,7 +90,6 @@ def _include_sqlalchemy(obj):
91
90
setattr (obj , key , getattr (module , key ))
92
91
# Note: obj.Table does not attempt to be a SQLAlchemy Table class.
93
92
obj .Table = _make_table (obj )
94
- obj .mapper = signalling_mapper
95
93
obj .relationship = _wrap_with_default_query_class (obj .relationship )
96
94
obj .relation = _wrap_with_default_query_class (obj .relation )
97
95
obj .dynamic_loader = _wrap_with_default_query_class (obj .dynamic_loader )
@@ -155,67 +153,74 @@ def cursor_execute(self, execute, cursor, statement, parameters,
155
153
_calling_context (self .app_package ))))
156
154
157
155
158
- class _SignalTrackingMapperExtension (MapperExtension ):
159
-
160
- def after_delete (self , mapper , connection , instance ):
161
- return self ._record (mapper , instance , 'delete' )
156
+ class _SignallingSession (Session ):
162
157
163
- def after_insert (self , mapper , connection , instance ):
164
- return self ._record (mapper , instance , 'insert' )
158
+ def __init__ (self , db , autocommit = False , autoflush = False , ** options ):
159
+ self .app = db .get_app ()
160
+ self ._model_changes = {}
161
+ Session .__init__ (self , autocommit = autocommit , autoflush = autoflush ,
162
+ bind = db .engine ,
163
+ binds = db .get_binds (self .app ), ** options )
165
164
166
- def after_update (self , mapper , connection , instance ):
167
- return self ._record (mapper , instance , 'update' )
165
+ def get_bind (self , mapper , clause = None ):
166
+ # mapper is None if someone tries to just get a connection
167
+ if mapper is not None :
168
+ info = getattr (mapper .mapped_table , 'info' , {})
169
+ bind_key = info .get ('bind_key' )
170
+ if bind_key is not None :
171
+ state = get_state (self .app )
172
+ return state .db .get_engine (self .app , bind = bind_key )
173
+ return Session .get_bind (self , mapper , clause )
168
174
169
- def _record (self , mapper , model , operation ):
170
- s = orm .object_session (model )
171
- # Skip the operation tracking when a non signalling session
172
- # is used.
173
- if isinstance (s , _SignallingSessionExtension ):
174
- pk = tuple (mapper .primary_key_from_instance (model ))
175
- s ._model_changes [pk ] = (model , operation )
176
- return EXT_CONTINUE
177
175
176
+ class _SessionSignalEvents (object ):
178
177
179
- class _SignallingSessionExtension (SessionExtension ):
178
+ def register (self ):
179
+ listen (Session , 'before_commit' , self .session_signal_before_commit )
180
+ listen (Session , 'after_commit' , self .session_signal_after_commit )
181
+ listen (Session , 'after_rollback' , self .session_signal_after_rollback )
180
182
181
- def before_commit (self , session ):
183
+ @staticmethod
184
+ def session_signal_before_commit (session ):
182
185
d = session ._model_changes
183
186
if d :
184
187
before_models_committed .send (session .app , changes = d .values ())
185
- return EXT_CONTINUE
186
188
187
- def after_commit (self , session ):
189
+ @staticmethod
190
+ def session_signal_after_commit (session ):
188
191
d = session ._model_changes
189
192
if d :
190
193
models_committed .send (session .app , changes = d .values ())
191
194
d .clear ()
192
- return EXT_CONTINUE
193
195
194
- def after_rollback (self , session ):
196
+ @staticmethod
197
+ def session_signal_after_rollback (session ):
195
198
session ._model_changes .clear ()
196
- return EXT_CONTINUE
197
199
198
200
199
- class _SignallingSession ( Session ):
201
+ class _MapperSignalEvents ( object ):
200
202
201
- def __init__ (self , db , autocommit = False , autoflush = False , ** options ):
202
- self .app = db .get_app ()
203
- self ._model_changes = {}
204
- Session .__init__ (self , autocommit = autocommit , autoflush = autoflush ,
205
- extension = db .session_extensions ,
206
- bind = db .engine ,
207
- binds = db .get_binds (self .app ), ** options )
203
+ def __init__ (self , mapper ):
204
+ self .mapper = mapper
208
205
209
- def get_bind (self , mapper , clause = None ):
210
- # mapper is None if someone tries to just get a connection
211
- if mapper is not None :
212
- info = getattr (mapper .mapped_table , 'info' , {})
213
- bind_key = info .get ('bind_key' )
214
- if bind_key is not None :
215
- state = get_state (self .app )
216
- return state .db .get_engine (self .app , bind = bind_key )
217
- return Session .get_bind (self , mapper , clause )
206
+ def register (self ):
207
+ listen (self .mapper , 'after_delete' , self .mapper_signal_after_delete )
208
+ listen (self .mapper , 'after_insert' , self .mapper_signal_after_insert )
209
+ listen (self .mapper , 'after_update' , self .mapper_signal_after_update )
210
+
211
+ def mapper_signal_after_delete (self , mapper , connection , target ):
212
+ self ._record (mapper , target , 'delete' )
218
213
214
+ def mapper_signal_after_insert (self , mapper , connection , target ):
215
+ self ._record (mapper , target , 'insert' )
216
+
217
+ def mapper_signal_after_update (self , mapper , connection , target ):
218
+ self ._record (mapper , target , 'update' )
219
+
220
+ @staticmethod
221
+ def _record (mapper , target , operation ):
222
+ pk = tuple (mapper .primary_key_from_instance (target ))
223
+ orm .object_session (target )._model_changes [pk ] = (target , operation )
219
224
220
225
def get_debug_queries ():
221
226
"""In debug mode Flask-SQLAlchemy will log all the SQL queries sent
@@ -504,14 +509,6 @@ def get_state(app):
504
509
return app .extensions ['sqlalchemy' ]
505
510
506
511
507
- def signalling_mapper (* args , ** kwargs ):
508
- """Replacement for mapper that injects some extra extensions"""
509
- extensions = to_list (kwargs .pop ('extension' , None ), [])
510
- extensions .append (_SignalTrackingMapperExtension ())
511
- kwargs ['extension' ] = extensions
512
- return sqlalchemy .orm .mapper (* args , ** kwargs )
513
-
514
-
515
512
class _SQLAlchemyState (object ):
516
513
"""Remembers configuration for the (db, app) tuple."""
517
514
@@ -619,11 +616,10 @@ class User(db.Model):
619
616
a custom function which will define the SQLAlchemy session's scoping.
620
617
"""
621
618
622
- def __init__ (self , app = None , use_native_unicode = True ,
623
- session_extensions = None , session_options = None ):
619
+ def __init__ (self , app = None ,
620
+ use_native_unicode = True ,
621
+ session_options = None ):
624
622
self .use_native_unicode = use_native_unicode
625
- self .session_extensions = to_list (session_extensions , []) + \
626
- [_SignallingSessionExtension ()]
627
623
628
624
if session_options is None :
629
625
session_options = {}
@@ -643,6 +639,8 @@ def __init__(self, app=None, use_native_unicode=True,
643
639
self .app = None
644
640
645
641
_include_sqlalchemy (self )
642
+ _MapperSignalEvents (self .mapper ).register ()
643
+ _SessionSignalEvents ().register ()
646
644
self .Query = BaseQuery
647
645
648
646
@property
@@ -662,7 +660,6 @@ def create_scoped_session(self, options=None):
662
660
def make_declarative_base (self ):
663
661
"""Creates the declarative base."""
664
662
base = declarative_base (cls = Model , name = 'Model' ,
665
- mapper = signalling_mapper ,
666
663
metaclass = _BoundDeclarativeMeta )
667
664
base .query = _QueryProperty (self )
668
665
return base
0 commit comments