9
9
:license: BSD, see LICENSE for more details.
10
10
"""
11
11
from __future__ import with_statement , absolute_import
12
+ import re
12
13
import sys
13
14
import time
14
15
import sqlalchemy
15
16
from math import ceil
16
17
from types import MethodType
18
+ from functools import partial
17
19
from flask import _request_ctx_stack , abort
20
+ from flask .signals import Namespace
18
21
from operator import itemgetter
19
22
from threading import Lock
20
23
from sqlalchemy import orm
21
24
from sqlalchemy .orm .exc import UnmappedClassError
25
+ from sqlalchemy .orm .mapper import Mapper
26
+ from sqlalchemy .orm .session import Session
27
+ from sqlalchemy .orm .interfaces import MapperExtension , SessionExtension , \
28
+ EXT_CONTINUE
22
29
from sqlalchemy .interfaces import ConnectionProxy
23
30
from sqlalchemy .engine .url import make_url
24
31
from sqlalchemy .ext .declarative import declarative_base
32
+ from sqlalchemy .util import to_list
25
33
26
34
# the best timer function for the platform
27
35
if sys .platform == 'win32' :
30
38
_timer = time .time
31
39
32
40
41
+ _camelcase_re = re .compile (r'([A-Z]+)(?=[a-z0-9])' )
42
+ _signals = Namespace ()
43
+
44
+
45
+ models_committed = _signals .signal ('models-committed' )
46
+ before_models_committed = _signals .signal ('before-models-committed' )
47
+
48
+
33
49
def _create_scoped_session (db ):
34
- return orm .scoped_session (lambda : orm .create_session (autocommit = False ,
35
- autoflush = False ,
36
- bind = db .engine ))
50
+ return orm .scoped_session (partial (_SignallingSession , db ))
37
51
38
52
39
53
def _include_sqlalchemy (obj ):
@@ -100,6 +114,62 @@ def cursor_execute(self, execute, cursor, statement, parameters,
100
114
_calling_context (self .app_package ))))
101
115
102
116
117
+ class _SignalTrackingMapperExtension (MapperExtension ):
118
+
119
+ def after_delete (self , mapper , connection , instance ):
120
+ return self ._record (mapper , instance , 'delete' )
121
+
122
+ def after_insert (self , mapper , connection , instance ):
123
+ return self ._record (mapper , instance , 'insert' )
124
+
125
+ def after_update (self , mapper , connection , instance ):
126
+ return self ._record (mapper , instance , 'update' )
127
+
128
+ def _record (self , mapper , model , operation ):
129
+ pk = tuple (mapper .primary_key_from_instance (model ))
130
+ orm .object_session (model )._model_changes [pk ] = (model , operation )
131
+ return EXT_CONTINUE
132
+
133
+
134
+ class _SignalTrackingMapper (Mapper ):
135
+
136
+ def __init__ (self , * args , ** kwargs ):
137
+ extensions = to_list (kwargs .pop ('extension' , None ), [])
138
+ extensions .append (_SignalTrackingMapperExtension ())
139
+ kwargs ['extension' ] = extensions
140
+ Mapper .__init__ (self , * args , ** kwargs )
141
+
142
+
143
+ class _SignallingSessionExtension (SessionExtension ):
144
+
145
+ def before_commit (self , session ):
146
+ d = session ._model_changes
147
+ if d :
148
+ before_models_committed .send (session .app , changes = d .values ())
149
+ return EXT_CONTINUE
150
+
151
+ def after_commit (self , session ):
152
+ d = session ._model_changes
153
+ if d :
154
+ models_committed .send (session .app , changes = d .values ())
155
+ d .clear ()
156
+ return EXT_CONTINUE
157
+
158
+ def after_rollback (self , session ):
159
+ session ._model_changes .clear ()
160
+ return EXT_CONTINUE
161
+
162
+
163
+ class _SignallingSession (Session ):
164
+
165
+ def __init__ (self , db ):
166
+ Session .__init__ (self , autocommit = False , autoflush = False ,
167
+ extension = [_SignallingSessionExtension ()],
168
+ bind = db .engine )
169
+ self .app = db .app or _request_ctx_stack .top .app
170
+ self ._model_changes = {}
171
+
172
+
103
173
def get_debug_queries ():
104
174
"""In debug mode Flask-SQLAlchemy will log all the SQL queries sent
105
175
to the database. This information is available until the end of request
@@ -323,6 +393,21 @@ def get_engine(self):
323
393
return rv
324
394
325
395
396
+ class _ModelTableNameDescriptor (object ):
397
+
398
+ def __get__ (self , obj , type ):
399
+ tablename = type .__dict__ .get ('__tablename__' )
400
+ if not tablename :
401
+ def _join (match ):
402
+ word = match .group ()
403
+ if len (word ) > 1 :
404
+ return ('_%s_%s' % (word [:- 1 ], word [- 1 ])).lower ()
405
+ return '_' + word .lower ()
406
+ tablename = _camelcase_re .sub (_join , type .__name__ ).lstrip ('_' )
407
+ setattr (type , '__tablename__' , tablename )
408
+ return tablename
409
+
410
+
326
411
class Model (object ):
327
412
"""Baseclass for custom user models."""
328
413
@@ -334,6 +419,11 @@ class Model(object):
334
419
#: database for instances of this model.
335
420
query = None
336
421
422
+ #: arguments for the mapper
423
+ __mapper_cls__ = _SignalTrackingMapper
424
+
425
+ __tablename__ = _ModelTableNameDescriptor ()
426
+
337
427
338
428
class SQLAlchemy (object ):
339
429
"""This class is used to control the SQLAlchemy integration to one
0 commit comments