Skip to content

Commit bd6ce9e

Browse files
Thadeus Burgessmitsuhiko
authored andcommitted
Refactored session_extensions. Now declared in SQLAlchemy __init__
1 parent eb82f71 commit bd6ce9e

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

flaskext/sqlalchemy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,7 @@ def after_rollback(self, session):
162162

163163
class _SignallingSession(Session):
164164

165-
def __init__(self, db):
166-
if isinstance(db.session_extensions, list):
167-
db.session_extensions.append(_SignallingSessionExtension())
168-
else:
169-
db.session_extensions = [db.session_extensions, _SignallingSessionExtension()]
170-
165+
def __init__(self, db):
171166
Session.__init__(self, autocommit=False, autoflush=False,
172167
extension=db.session_extensions,
173168
bind=db.engine)
@@ -474,10 +469,15 @@ class User(db.Model):
474469
pw_hash = db.Column(db.String(80))
475470
"""
476471

477-
def __init__(self, app=None, use_native_unicode=True, session_extensions=[]):
472+
def __init__(self, app=None, use_native_unicode=True, session_extensions=None):
478473
self.use_native_unicode = use_native_unicode
479474
self.session_extensions = session_extensions
480475

476+
if self.session_extensions:
477+
self.session_extensions = to_list(self.session_extensions) + [_SignallingSessionExtension()]
478+
else:
479+
self.session_extensions = [_SignallingSessionExtension()]
480+
481481
self.session = _create_scoped_session(self)
482482

483483
self.Model = declarative_base(cls=Model, name='Model')

0 commit comments

Comments
 (0)