Skip to content

Commit 88892a8

Browse files
singingwolfboydavidism
authored andcommitted
Support binds on abstract models (pallets-eco#373)
1 parent 095f19a commit 88892a8

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

flask_sqlalchemy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,9 @@ def _join(match):
599599
return DeclarativeMeta.__new__(cls, name, bases, d)
600600

601601
def __init__(self, name, bases, d):
602-
bind_key = d.pop('__bind_key__', None)
602+
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
603603
DeclarativeMeta.__init__(self, name, bases, d)
604-
if bind_key is not None:
604+
if bind_key is not None and hasattr(self, '__table__'):
605605
self.__table__.info['bind_key'] = bind_key
606606

607607

test_sqlalchemy.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import with_statement
22

33
import atexit
4+
import tempfile
5+
import os
46
import unittest
57
from datetime import datetime
68
import flask
@@ -384,12 +386,10 @@ def index():
384386
class BindsTestCase(unittest.TestCase):
385387

386388
def test_basic_binds(self):
387-
import tempfile
388389
_, db1 = tempfile.mkstemp()
389390
_, db2 = tempfile.mkstemp()
390391

391392
def _remove_files():
392-
import os
393393
try:
394394
os.remove(db1)
395395
os.remove(db2)
@@ -456,6 +456,44 @@ class Baz(db.Model):
456456
Baz.__table__: db.get_engine(app, None)
457457
})
458458

459+
def test_abstract_binds(self):
460+
_, db1 = tempfile.mkstemp()
461+
_, db2 = tempfile.mkstemp()
462+
463+
def _remove_files():
464+
try:
465+
os.remove(db1)
466+
os.remove(db2)
467+
except IOError:
468+
pass
469+
atexit.register(_remove_files)
470+
471+
app = flask.Flask(__name__)
472+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
473+
app.config['SQLALCHEMY_BINDS'] = {
474+
'foo': 'sqlite:///' + db1,
475+
'bar': 'sqlite:///' + db2
476+
}
477+
db = sqlalchemy.SQLAlchemy(app)
478+
479+
class AbstractFooBoundModel(db.Model):
480+
__abstract__ = True
481+
__bind_key__ = 'foo'
482+
483+
class FooBoundModel(AbstractFooBoundModel):
484+
id = db.Column(db.Integer, primary_key=True)
485+
486+
db.create_all()
487+
488+
# does the model have the correct engines?
489+
self.assertEqual(db.metadata.tables['foo_bound_model'].info['bind_key'], 'foo')
490+
491+
# see the tables created in an engine
492+
metadata = db.MetaData()
493+
metadata.reflect(bind=db.get_engine(app, 'foo'))
494+
self.assertEqual(len(metadata.tables), 1)
495+
self.assertTrue('foo_bound_model' in metadata.tables)
496+
459497

460498
class DefaultQueryClassTestCase(unittest.TestCase):
461499

0 commit comments

Comments
 (0)