Skip to content

Commit d89c215

Browse files
committed
Possible way to mixin with DeclarativeBase
1 parent 41dccc7 commit d89c215

File tree

4 files changed

+151
-3
lines changed

4 files changed

+151
-3
lines changed

.pdm-python

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/Users/pamelafox/flask-sqlalchemy/.venv/bin/python

src/flask_sqlalchemy/extension.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,11 @@ def _make_declarative_base(
503503
.. versionchanged:: 2.3
504504
``model`` can be an already created declarative model class.
505505
"""
506-
if not isinstance(model, sa.orm.DeclarativeMeta):
506+
if (
507+
not isinstance(model, sa.orm.DeclarativeMeta)
508+
and not issubclass(model, sa.orm.DeclarativeBase)
509+
and not issubclass(model, sa.orm.DeclarativeBaseNoMeta)
510+
):
507511
metadata = self._make_metadata(None)
508512
model = sa.orm.declarative_base(
509513
metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta

src/flask_sqlalchemy/model.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def should_set_tablename(cls: type) -> bool:
182182
joined-table inheritance. If no primary key is found, the name will be unset.
183183
"""
184184
if cls.__dict__.get("__abstract__", False) or not any(
185-
isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]
185+
(isinstance(b, sa.orm.DeclarativeMeta) or b is sa.orm.DeclarativeBase)
186+
for b in cls.__mro__[1:]
186187
):
187188
return False
188189

@@ -196,7 +197,10 @@ def should_set_tablename(cls: type) -> bool:
196197
return not (
197198
base is cls
198199
or base.__dict__.get("__abstract__", False)
199-
or not isinstance(base, sa.orm.DeclarativeMeta)
200+
or not (
201+
isinstance(base, sa.orm.DeclarativeMeta)
202+
or base is sa.orm.DeclarativeBase
203+
)
200204
)
201205

202206
return True
@@ -212,3 +216,99 @@ class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta):
212216
"""SQLAlchemy declarative metaclass that provides ``__bind_key__`` and
213217
``__tablename__`` support.
214218
"""
219+
220+
221+
class DefaultMixin:
222+
"""A mixin that provides Flask-SQLAlchemy default functionality:
223+
* sets a model's ``__tablename__`` by converting the
224+
``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models
225+
that do not otherwise define ``__tablename__``. If a model does not define a primary
226+
key, it will not generate a name or ``__table__``, for single-table inheritance.
227+
* sets a model's ``metadata`` based on its ``__bind_key__``.
228+
If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is
229+
ignored. If the ``metadata`` is the same as the parent model, it will not be set
230+
directly on the child model.
231+
* Provides a default repr based on the model's primary key.
232+
"""
233+
234+
__fsa__: SQLAlchemy
235+
metadata: sa.MetaData
236+
__tablename__: str
237+
__table__: sa.Table
238+
239+
def __init_subclass__(cls, **kwargs):
240+
if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__):
241+
bind_key = getattr(cls, "__bind_key__", None)
242+
parent_metadata = getattr(cls, "metadata", None)
243+
metadata = cls.__fsa__._make_metadata(bind_key)
244+
245+
if metadata is not parent_metadata:
246+
cls.metadata = metadata
247+
248+
if should_set_tablename(cls):
249+
cls.__tablename__ = camel_to_snake_case(cls.__name__)
250+
251+
super().__init_subclass__(**kwargs)
252+
253+
# __table_cls__ has run. If no table was created, use the parent table.
254+
if (
255+
"__tablename__" not in cls.__dict__
256+
and "__table__" in cls.__dict__
257+
and cls.__dict__["__table__"] is None
258+
):
259+
del cls.__table__
260+
261+
@classmethod
262+
def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None:
263+
"""This is called by SQLAlchemy during mapper setup. It determines the final
264+
table object that the model will use.
265+
266+
If no primary key is found, that indicates single-table inheritance, so no table
267+
will be created and ``__tablename__`` will be unset.
268+
"""
269+
schema = kwargs.get("schema")
270+
271+
if schema is None:
272+
key = args[0]
273+
else:
274+
key = f"{schema}.{args[0]}"
275+
276+
# Check if a table with this name already exists. Allows reflected tables to be
277+
# applied to models by name.
278+
if key in cls.metadata.tables:
279+
return sa.Table(*args, **kwargs)
280+
281+
# If a primary key is found, create a table for joined-table inheritance.
282+
for arg in args:
283+
if (isinstance(arg, sa.Column) and arg.primary_key) or isinstance(
284+
arg, sa.PrimaryKeyConstraint
285+
):
286+
return sa.Table(*args, **kwargs)
287+
288+
# If no base classes define a table, return one that's missing a primary key
289+
# so SQLAlchemy shows the correct error.
290+
for base in cls.__mro__[1:-1]:
291+
if "__table__" in base.__dict__:
292+
break
293+
else:
294+
return sa.Table(*args, **kwargs)
295+
296+
# Single-table inheritance, use the parent table name. __init__ will unset
297+
# __table__ based on this.
298+
if "__tablename__" in cls.__dict__:
299+
del cls.__tablename__
300+
301+
return None
302+
303+
def __repr__(self) -> str:
304+
state = sa.inspect(self)
305+
assert state is not None
306+
307+
if state.transient:
308+
pk = f"(transient {id(self)})"
309+
elif state.pending:
310+
pk = f"(pending {id(self)})"
311+
else:
312+
pk = ", ".join(map(str, state.identity))
313+
314+
return f"<{type(self).__name__} {pk}>"

tests/test_model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
import sqlalchemy as sa
77
import sqlalchemy.orm
88
from flask import Flask
9+
from sqlalchemy.orm import DeclarativeBase
10+
from sqlalchemy.orm import Mapped
11+
from sqlalchemy.orm import mapped_column
912

1013
from flask_sqlalchemy import SQLAlchemy
1114
from flask_sqlalchemy.model import DefaultMeta
15+
from flask_sqlalchemy.model import DefaultMixin
1216
from flask_sqlalchemy.model import Model
1317

1418

@@ -28,6 +32,45 @@ class CustomModel(Model):
2832
assert isinstance(db.Model, DefaultMeta)
2933

3034

35+
@pytest.mark.usefixtures("app_ctx")
36+
def test_custom_model_sqlalchemy20_class(app: Flask) -> None:
37+
from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
38+
39+
class Base(DeclarativeBase):
40+
pass
41+
42+
db = SQLAlchemy(app, model_class=Base)
43+
44+
# Check the model class is instantiated with the correct metaclass
45+
assert issubclass(db.Model, Base)
46+
assert isinstance(db.Model, type)
47+
assert isinstance(db.Model, DeclarativeAttributeIntercept)
48+
# Check that additional attributes are added to the model class
49+
assert db.Model.query_class is db.Query
50+
51+
# Now create a model that inherits from that declarative base
52+
class Quiz(DefaultMixin, db.Model):
53+
id: Mapped[int] = mapped_column(
54+
db.Integer, primary_key=True, autoincrement=True
55+
)
56+
title: Mapped[str] = mapped_column(db.String(255), nullable=False)
57+
58+
assert Quiz.__tablename__ == "quiz"
59+
assert isinstance(Quiz, DeclarativeAttributeIntercept)
60+
61+
db.create_all()
62+
quiz = Quiz(title="Python trivia")
63+
db.session.add(quiz)
64+
db.session.commit()
65+
66+
# Check column types are correct
67+
quiz_id: int = quiz.id
68+
quiz_title: str = quiz.title
69+
assert quiz_id == 1
70+
assert quiz_title == "Python trivia"
71+
assert repr(quiz) == f"<Quiz {quiz.id}>"
72+
73+
3174
@pytest.mark.usefixtures("app_ctx")
3275
@pytest.mark.parametrize("base", [Model, object])
3376
def test_custom_declarative_class(app: Flask, base: t.Any) -> None:

0 commit comments

Comments
 (0)