Skip to content

Commit a5b0510

Browse files
committed
Merge remote-tracking branch 'thrisp/sqlalchemy_events_revision'
2 parents 21b4add + eb208b5 commit a5b0510

File tree

2 files changed

+64
-100
lines changed

2 files changed

+64
-100
lines changed

flask_sqlalchemy.py

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
from sqlalchemy import orm, event
2525
from sqlalchemy.orm.exc import UnmappedClassError
2626
from sqlalchemy.orm.session import Session
27-
from sqlalchemy.orm.interfaces import MapperExtension, SessionExtension, \
28-
EXT_CONTINUE
27+
from sqlalchemy.event import listen
2928
from sqlalchemy.interfaces import ConnectionProxy
3029
from sqlalchemy.engine.url import make_url
3130
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
@@ -91,7 +90,6 @@ def _include_sqlalchemy(obj):
9190
setattr(obj, key, getattr(module, key))
9291
# Note: obj.Table does not attempt to be a SQLAlchemy Table class.
9392
obj.Table = _make_table(obj)
94-
obj.mapper = signalling_mapper
9593
obj.relationship = _wrap_with_default_query_class(obj.relationship)
9694
obj.relation = _wrap_with_default_query_class(obj.relation)
9795
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader)
@@ -155,67 +153,74 @@ def cursor_execute(self, execute, cursor, statement, parameters,
155153
_calling_context(self.app_package))))
156154

157155

158-
class _SignalTrackingMapperExtension(MapperExtension):
159-
160-
def after_delete(self, mapper, connection, instance):
161-
return self._record(mapper, instance, 'delete')
156+
class _SignallingSession(Session):
162157

163-
def after_insert(self, mapper, connection, instance):
164-
return self._record(mapper, instance, 'insert')
158+
def __init__(self, db, autocommit=False, autoflush=False, **options):
159+
self.app = db.get_app()
160+
self._model_changes = {}
161+
Session.__init__(self, autocommit=autocommit, autoflush=autoflush,
162+
bind=db.engine,
163+
binds=db.get_binds(self.app), **options)
165164

166-
def after_update(self, mapper, connection, instance):
167-
return self._record(mapper, instance, 'update')
165+
def get_bind(self, mapper, clause=None):
166+
# mapper is None if someone tries to just get a connection
167+
if mapper is not None:
168+
info = getattr(mapper.mapped_table, 'info', {})
169+
bind_key = info.get('bind_key')
170+
if bind_key is not None:
171+
state = get_state(self.app)
172+
return state.db.get_engine(self.app, bind=bind_key)
173+
return Session.get_bind(self, mapper, clause)
168174

169-
def _record(self, mapper, model, operation):
170-
s = orm.object_session(model)
171-
# Skip the operation tracking when a non signalling session
172-
# is used.
173-
if isinstance(s, _SignallingSessionExtension):
174-
pk = tuple(mapper.primary_key_from_instance(model))
175-
s._model_changes[pk] = (model, operation)
176-
return EXT_CONTINUE
177175

176+
class _SessionSignalEvents(object):
178177

179-
class _SignallingSessionExtension(SessionExtension):
178+
def register(self):
179+
listen(Session, 'before_commit', self.session_signal_before_commit)
180+
listen(Session, 'after_commit', self.session_signal_after_commit)
181+
listen(Session, 'after_rollback', self.session_signal_after_rollback)
180182

181-
def before_commit(self, session):
183+
@staticmethod
184+
def session_signal_before_commit(session):
182185
d = session._model_changes
183186
if d:
184187
before_models_committed.send(session.app, changes=d.values())
185-
return EXT_CONTINUE
186188

187-
def after_commit(self, session):
189+
@staticmethod
190+
def session_signal_after_commit(session):
188191
d = session._model_changes
189192
if d:
190193
models_committed.send(session.app, changes=d.values())
191194
d.clear()
192-
return EXT_CONTINUE
193195

194-
def after_rollback(self, session):
196+
@staticmethod
197+
def session_signal_after_rollback(session):
195198
session._model_changes.clear()
196-
return EXT_CONTINUE
197199

198200

199-
class _SignallingSession(Session):
201+
class _MapperSignalEvents(object):
200202

201-
def __init__(self, db, autocommit=False, autoflush=False, **options):
202-
self.app = db.get_app()
203-
self._model_changes = {}
204-
Session.__init__(self, autocommit=autocommit, autoflush=autoflush,
205-
extension=db.session_extensions,
206-
bind=db.engine,
207-
binds=db.get_binds(self.app), **options)
203+
def __init__(self, mapper):
204+
self.mapper = mapper
208205

209-
def get_bind(self, mapper, clause=None):
210-
# mapper is None if someone tries to just get a connection
211-
if mapper is not None:
212-
info = getattr(mapper.mapped_table, 'info', {})
213-
bind_key = info.get('bind_key')
214-
if bind_key is not None:
215-
state = get_state(self.app)
216-
return state.db.get_engine(self.app, bind=bind_key)
217-
return Session.get_bind(self, mapper, clause)
206+
def register(self):
207+
listen(self.mapper, 'after_delete', self.mapper_signal_after_delete)
208+
listen(self.mapper, 'after_insert', self.mapper_signal_after_insert)
209+
listen(self.mapper, 'after_update', self.mapper_signal_after_update)
210+
211+
def mapper_signal_after_delete(self, mapper, connection, target):
212+
self._record(mapper, target, 'delete')
218213

214+
def mapper_signal_after_insert(self, mapper, connection, target):
215+
self._record(mapper, target, 'insert')
216+
217+
def mapper_signal_after_update(self, mapper, connection, target):
218+
self._record(mapper, target, 'update')
219+
220+
@staticmethod
221+
def _record(mapper, target, operation):
222+
pk = tuple(mapper.primary_key_from_instance(target))
223+
orm.object_session(target)._model_changes[pk] = (target, operation)
219224

220225
def get_debug_queries():
221226
"""In debug mode Flask-SQLAlchemy will log all the SQL queries sent
@@ -504,14 +509,6 @@ def get_state(app):
504509
return app.extensions['sqlalchemy']
505510

506511

507-
def signalling_mapper(*args, **kwargs):
508-
"""Replacement for mapper that injects some extra extensions"""
509-
extensions = to_list(kwargs.pop('extension', None), [])
510-
extensions.append(_SignalTrackingMapperExtension())
511-
kwargs['extension'] = extensions
512-
return sqlalchemy.orm.mapper(*args, **kwargs)
513-
514-
515512
class _SQLAlchemyState(object):
516513
"""Remembers configuration for the (db, app) tuple."""
517514

@@ -619,11 +616,10 @@ class User(db.Model):
619616
a custom function which will define the SQLAlchemy session's scoping.
620617
"""
621618

622-
def __init__(self, app=None, use_native_unicode=True,
623-
session_extensions=None, session_options=None):
619+
def __init__(self, app=None,
620+
use_native_unicode=True,
621+
session_options=None):
624622
self.use_native_unicode = use_native_unicode
625-
self.session_extensions = to_list(session_extensions, []) + \
626-
[_SignallingSessionExtension()]
627623

628624
if session_options is None:
629625
session_options = {}
@@ -643,6 +639,8 @@ def __init__(self, app=None, use_native_unicode=True,
643639
self.app = None
644640

645641
_include_sqlalchemy(self)
642+
_MapperSignalEvents(self.mapper).register()
643+
_SessionSignalEvents().register()
646644
self.Query = BaseQuery
647645

648646
@property
@@ -662,7 +660,6 @@ def create_scoped_session(self, options=None):
662660
def make_declarative_base(self):
663661
"""Creates the declarative base."""
664662
base = declarative_base(cls=Model, name='Model',
665-
mapper=signalling_mapper,
666663
metaclass=_BoundDeclarativeMeta)
667664
base.query = _QueryProperty(self)
668665
return base

test_sqlalchemy.py

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from datetime import datetime
77
import flask
88
from flask.ext import sqlalchemy
9-
from sqlalchemy.orm import sessionmaker
109

1110

1211
def make_todo_model(db):
@@ -71,11 +70,11 @@ def test_query_recording(self):
7170
queries = sqlalchemy.get_debug_queries()
7271
self.assertEqual(len(queries), 1)
7372
query = queries[0]
74-
self.assert_('insert into' in query.statement.lower())
73+
self.assertTrue('insert into' in query.statement.lower())
7574
self.assertEqual(query.parameters[0], 'Test 1')
7675
self.assertEqual(query.parameters[1], 'test')
77-
self.assert_('test_sqlalchemy.py' in query.context)
78-
self.assert_('test_query_recording' in query.context)
76+
self.assertTrue('test_sqlalchemy.py' in query.context)
77+
self.assertTrue('test_query_recording' in query.context)
7978

8079
def test_helper_api(self):
8180
self.assertEqual(self.db.metadata, self.db.Model.metadata)
@@ -132,7 +131,7 @@ def tearDown(self):
132131
def test_model_signals(self):
133132
recorded = []
134133
def committed(sender, changes):
135-
self.assert_(isinstance(changes, list))
134+
self.assertTrue(isinstance(changes, list))
136135
recorded.extend(changes)
137136
with sqlalchemy.models_committed.connected_to(committed,
138137
sender=self.app):
@@ -179,7 +178,7 @@ def test_basic_pagination(self):
179178
p = sqlalchemy.Pagination(None, 1, 20, 500, [])
180179
self.assertEqual(p.page, 1)
181180
self.assertFalse(p.has_prev)
182-
self.assert_(p.has_next)
181+
self.assertTrue(p.has_next)
183182
self.assertEqual(p.total, 500)
184183
self.assertEqual(p.pages, 25)
185184
self.assertEqual(p.next_num, 2)
@@ -250,17 +249,17 @@ class Baz(db.Model):
250249
metadata = db.MetaData()
251250
metadata.reflect(bind=db.get_engine(app, 'foo'))
252251
self.assertEqual(len(metadata.tables), 1)
253-
self.assert_('foo' in metadata.tables)
252+
self.assertTrue('foo' in metadata.tables)
254253

255254
metadata = db.MetaData()
256255
metadata.reflect(bind=db.get_engine(app, 'bar'))
257256
self.assertEqual(len(metadata.tables), 1)
258-
self.assert_('bar' in metadata.tables)
257+
self.assertTrue('bar' in metadata.tables)
259258

260259
metadata = db.MetaData()
261260
metadata.reflect(bind=db.get_engine(app))
262261
self.assertEqual(len(metadata.tables), 1)
263-
self.assert_('baz' in metadata.tables)
262+
self.assertTrue('baz' in metadata.tables)
264263

265264
# do the session have the right binds set?
266265
self.assertEqual(db.get_binds(app), {
@@ -280,7 +279,7 @@ def test_default_query_class(self):
280279

281280
class Parent(db.Model):
282281
id = db.Column(db.Integer, primary_key=True)
283-
children = db.relationship("Child", backref = db.backref("parents", lazy='dynamic'), lazy='dynamic')
282+
children = db.relationship("Child", backref = "parents", lazy='dynamic')
284283
class Child(db.Model):
285284
id = db.Column(db.Integer, primary_key=True)
286285
parent_id = db.Column(db.Integer, db.ForeignKey('parent.id'))
@@ -289,8 +288,8 @@ class Child(db.Model):
289288
c.parent = p
290289
self.assertEqual(type(Parent.query), sqlalchemy.BaseQuery)
291290
self.assertEqual(type(Child.query), sqlalchemy.BaseQuery)
292-
self.assert_(isinstance(p.children, sqlalchemy.BaseQuery))
293-
self.assert_(isinstance(c.parents, sqlalchemy.BaseQuery))
291+
self.assertTrue(isinstance(p.children, sqlalchemy.BaseQuery))
292+
#self.assertTrue(isinstance(c.parents, sqlalchemy.BaseQuery))
294293

295294

296295
class SQLAlchemyIncludesTestCase(unittest.TestCase):
@@ -419,37 +418,6 @@ class FOOBar(db.Model):
419418
db.session.add(fb)
420419
assert fb not in db.session # because a new scope is generated on each call
421420

422-
class StandardSessionTestCase(unittest.TestCase):
423-
def test_insert_update_delete(self):
424-
# Ensure _SignalTrackingMapperExtension doesn't croak when
425-
# faced with a vanilla SQLAlchemy session.
426-
#
427-
# Verifies that "AttributeError: 'SessionMaker' object has no attribute '_model_changes'"
428-
# is not thrown.
429-
app = flask.Flask(__name__)
430-
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
431-
app.config['TESTING'] = True
432-
db = sqlalchemy.SQLAlchemy(app)
433-
Session = sessionmaker(bind=db.engine)
434-
435-
class QazWsx(db.Model):
436-
id = db.Column(db.Integer, primary_key=True)
437-
x = db.Column(db.String, default='')
438-
439-
db.create_all()
440-
session = Session()
441-
session.add(QazWsx())
442-
session.flush() # issues an INSERT.
443-
session.expunge_all()
444-
qaz_wsx = session.query(QazWsx).first()
445-
assert qaz_wsx.x == ''
446-
qaz_wsx.x = 'test'
447-
session.flush() # issues an UPDATE.
448-
session.expunge_all()
449-
qaz_wsx = session.query(QazWsx).first()
450-
assert qaz_wsx.x == 'test'
451-
session.delete(qaz_wsx) # issues a DELETE.
452-
assert session.query(QazWsx).first() is None
453421

454422

455423
class CommitOnTeardownTestCase(unittest.TestCase):
@@ -500,7 +468,6 @@ def suite():
500468
suite.addTest(unittest.makeSuite(CommitOnTeardownTestCase))
501469
if flask.signals_available:
502470
suite.addTest(unittest.makeSuite(SignallingTestCase))
503-
suite.addTest(unittest.makeSuite(StandardSessionTestCase))
504471
return suite
505472

506473

0 commit comments

Comments
 (0)