Skip to content

Commit 4855356

Browse files
committed
Added signal support and automatic table name setting
1 parent 4cee6f5 commit 4855356

File tree

2 files changed

+170
-20
lines changed

2 files changed

+170
-20
lines changed

flaskext/sqlalchemy.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,27 @@
99
:license: BSD, see LICENSE for more details.
1010
"""
1111
from __future__ import with_statement, absolute_import
12+
import re
1213
import sys
1314
import time
1415
import sqlalchemy
1516
from math import ceil
1617
from types import MethodType
18+
from functools import partial
1719
from flask import _request_ctx_stack, abort
20+
from flask.signals import Namespace
1821
from operator import itemgetter
1922
from threading import Lock
2023
from sqlalchemy import orm
2124
from sqlalchemy.orm.exc import UnmappedClassError
25+
from sqlalchemy.orm.mapper import Mapper
26+
from sqlalchemy.orm.session import Session
27+
from sqlalchemy.orm.interfaces import MapperExtension, SessionExtension, \
28+
EXT_CONTINUE
2229
from sqlalchemy.interfaces import ConnectionProxy
2330
from sqlalchemy.engine.url import make_url
2431
from sqlalchemy.ext.declarative import declarative_base
32+
from sqlalchemy.util import to_list
2533

2634
# the best timer function for the platform
2735
if sys.platform == 'win32':
@@ -30,10 +38,16 @@
3038
_timer = time.time
3139

3240

41+
_camelcase_re = re.compile(r'([A-Z]+)(?=[a-z0-9])')
42+
_signals = Namespace()
43+
44+
45+
models_committed = _signals.signal('models-committed')
46+
before_models_committed = _signals.signal('before-models-committed')
47+
48+
3349
def _create_scoped_session(db):
34-
return orm.scoped_session(lambda: orm.create_session(autocommit=False,
35-
autoflush=False,
36-
bind=db.engine))
50+
return orm.scoped_session(partial(_SignallingSession, db))
3751

3852

3953
def _include_sqlalchemy(obj):
@@ -100,6 +114,62 @@ def cursor_execute(self, execute, cursor, statement, parameters,
100114
_calling_context(self.app_package))))
101115

102116

117+
class _SignalTrackingMapperExtension(MapperExtension):
118+
119+
def after_delete(self, mapper, connection, instance):
120+
return self._record(mapper, instance, 'delete')
121+
122+
def after_insert(self, mapper, connection, instance):
123+
return self._record(mapper, instance, 'insert')
124+
125+
def after_update(self, mapper, connection, instance):
126+
return self._record(mapper, instance, 'update')
127+
128+
def _record(self, mapper, model, operation):
129+
pk = tuple(mapper.primary_key_from_instance(model))
130+
orm.object_session(model)._model_changes[pk] = (model, operation)
131+
return EXT_CONTINUE
132+
133+
134+
class _SignalTrackingMapper(Mapper):
135+
136+
def __init__(self, *args, **kwargs):
137+
extensions = to_list(kwargs.pop('extension', None), [])
138+
extensions.append(_SignalTrackingMapperExtension())
139+
kwargs['extension'] = extensions
140+
Mapper.__init__(self, *args, **kwargs)
141+
142+
143+
class _SignallingSessionExtension(SessionExtension):
144+
145+
def before_commit(self, session):
146+
d = session._model_changes
147+
if d:
148+
before_models_committed.send(session.app, changes=d.values())
149+
return EXT_CONTINUE
150+
151+
def after_commit(self, session):
152+
d = session._model_changes
153+
if d:
154+
models_committed.send(session.app, changes=d.values())
155+
d.clear()
156+
return EXT_CONTINUE
157+
158+
def after_rollback(self, session):
159+
session._model_changes.clear()
160+
return EXT_CONTINUE
161+
162+
163+
class _SignallingSession(Session):
164+
165+
def __init__(self, db):
166+
Session.__init__(self, autocommit=False, autoflush=False,
167+
extension=[_SignallingSessionExtension()],
168+
bind=db.engine)
169+
self.app = db.app or _request_ctx_stack.top.app
170+
self._model_changes = {}
171+
172+
103173
def get_debug_queries():
104174
"""In debug mode Flask-SQLAlchemy will log all the SQL queries sent
105175
to the database. This information is available until the end of request
@@ -323,6 +393,21 @@ def get_engine(self):
323393
return rv
324394

325395

396+
class _ModelTableNameDescriptor(object):
397+
398+
def __get__(self, obj, type):
399+
tablename = type.__dict__.get('__tablename__')
400+
if not tablename:
401+
def _join(match):
402+
word = match.group()
403+
if len(word) > 1:
404+
return ('_%s_%s' % (word[:-1], word[-1])).lower()
405+
return '_' + word.lower()
406+
tablename = _camelcase_re.sub(_join, type.__name__).lstrip('_')
407+
setattr(type, '__tablename__', tablename)
408+
return tablename
409+
410+
326411
class Model(object):
327412
"""Baseclass for custom user models."""
328413

@@ -334,6 +419,11 @@ class Model(object):
334419
#: database for instances of this model.
335420
query = None
336421

422+
#: arguments for the mapper
423+
__mapper_cls__ = _SignalTrackingMapper
424+
425+
__tablename__ = _ModelTableNameDescriptor()
426+
337427

338428
class SQLAlchemy(object):
339429
"""This class is used to control the SQLAlchemy integration to one

test_sqlalchemy.py

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,40 @@
66
from flaskext import sqlalchemy
77

88

9+
def make_todo_model(db):
10+
class Todo(db.Model):
11+
__tablename__ = 'todos'
12+
id = db.Column('todo_id', db.Integer, primary_key=True)
13+
title = db.Column(db.String(60))
14+
text = db.Column(db.String)
15+
done = db.Column(db.Boolean)
16+
pub_date = db.Column(db.DateTime)
17+
18+
def __init__(self, title, text):
19+
self.title = title
20+
self.text = text
21+
self.done = False
22+
self.pub_date = datetime.utcnow()
23+
return Todo
24+
25+
926
class BasicAppTestCase(unittest.TestCase):
1027

1128
def setUp(self):
1229
app = flask.Flask(__name__)
1330
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
1431
app.config['TESTING'] = True
1532
db = sqlalchemy.SQLAlchemy(app)
16-
17-
class Todo(db.Model):
18-
__tablename__ = 'todos'
19-
id = db.Column('todo_id', db.Integer, primary_key=True)
20-
title = db.Column(db.String(60))
21-
text = db.Column(db.String)
22-
done = db.Column(db.Boolean)
23-
pub_date = db.Column(db.DateTime)
24-
25-
def __init__(self, title, text):
26-
self.title = title
27-
self.text = text
28-
self.done = False
29-
self.pub_date = datetime.utcnow()
33+
self.Todo = make_todo_model(db)
3034

3135
@app.route('/')
3236
def index():
33-
return '\n'.join(x.title for x in Todo.query.all())
37+
return '\n'.join(x.title for x in self.Todo.query.all())
3438

3539
@app.route('/add', methods=['POST'])
3640
def add():
3741
form = flask.request.form
38-
todo = Todo(form['title'], form['text'])
42+
todo = self.Todo(form['title'], form['text'])
3943
db.session.add(todo)
4044
db.session.commit()
4145
return 'added'
@@ -44,7 +48,6 @@ def add():
4448

4549
self.app = app
4650
self.db = db
47-
self.Todo = Todo
4851

4952
def tearDown(self):
5053
self.db.drop_all()
@@ -83,6 +86,63 @@ def test_helper_api(self):
8386
self.assertEqual(self.db.metadata, self.db.Model.metadata)
8487

8588

89+
class SignallingTestCase(unittest.TestCase):
90+
91+
def setUp(self):
92+
self.app = app = flask.Flask(__name__)
93+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
94+
app.config['TESTING'] = True
95+
self.db = sqlalchemy.SQLAlchemy(app)
96+
self.Todo = make_todo_model(self.db)
97+
self.db.create_all()
98+
99+
def tearDown(self):
100+
self.db.drop_all()
101+
102+
def test_model_signals(self):
103+
recorded = []
104+
def committed(sender, changes):
105+
self.assert_(isinstance(changes, list))
106+
recorded.extend(changes)
107+
with sqlalchemy.models_committed.connected_to(committed,
108+
sender=self.app):
109+
todo = self.Todo('Awesome', 'the text')
110+
self.db.session.add(todo)
111+
self.assertEqual(len(recorded), 0)
112+
self.db.session.commit()
113+
self.assertEqual(len(recorded), 1)
114+
self.assertEqual(recorded[0][0], todo)
115+
self.assertEqual(recorded[0][1], 'insert')
116+
del recorded[:]
117+
todo.text = 'aha'
118+
self.db.session.commit()
119+
self.assertEqual(len(recorded), 1)
120+
self.assertEqual(recorded[0][0], todo)
121+
self.assertEqual(recorded[0][1], 'update')
122+
del recorded[:]
123+
self.db.session.delete(todo)
124+
self.db.session.commit()
125+
self.assertEqual(len(recorded), 1)
126+
self.assertEqual(recorded[0][0], todo)
127+
self.assertEqual(recorded[0][1], 'delete')
128+
129+
130+
class HelperTestCase(unittest.TestCase):
131+
132+
def test_default_table_name(self):
133+
app = flask.Flask(__name__)
134+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
135+
db = sqlalchemy.SQLAlchemy(app)
136+
137+
class FOOBar(db.Model):
138+
id = db.Column(db.Integer, primary_key=True)
139+
class BazBar(db.Model):
140+
id = db.Column(db.Integer, primary_key=True)
141+
142+
self.assertEqual(FOOBar.__tablename__, 'foo_bar')
143+
self.assertEqual(BazBar.__tablename__, 'baz_bar')
144+
145+
86146
class PaginationTestCase(unittest.TestCase):
87147

88148
def test_basic_pagination(self):

0 commit comments

Comments
 (0)