Skip to content

Commit aa31854

Browse files
committed
Merge branch '3.0.x'
2 parents e81bcc5 + b206d0a commit aa31854

16 files changed

+172
-76
lines changed

CHANGES.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ Unreleased
77
- Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149`
88

99

10+
Version 3.0.4
11+
-------------
12+
13+
Released 2023-06-19
14+
15+
- Fix type hint for ``get_or_404`` return value. :pr:`1208`
16+
- Fix type hints for pyright (used by VS Code Pylance extension). :issue:`1205`
17+
18+
1019
Version 3.0.3
1120
-------------
1221

src/flask_sqlalchemy/extension.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from weakref import WeakKeyDictionary
66

77
import sqlalchemy as sa
8-
import sqlalchemy.event
9-
import sqlalchemy.exc
10-
import sqlalchemy.orm
8+
import sqlalchemy.event as sa_event
9+
import sqlalchemy.exc as sa_exc
10+
import sqlalchemy.orm as sa_orm
1111
from flask import abort
1212
from flask import current_app
1313
from flask import Flask
@@ -23,6 +23,8 @@
2323
from .session import Session
2424
from .table import _Table
2525

26+
_O = t.TypeVar("_O", bound=object) # Based on sqlalchemy.orm._typing.py
27+
2628

2729
class SQLAlchemy:
2830
"""Integrates SQLAlchemy with Flask. This handles setting up one or more engines,
@@ -122,7 +124,7 @@ def __init__(
122124
metadata: sa.MetaData | None = None,
123125
session_options: dict[str, t.Any] | None = None,
124126
query_class: type[Query] = Query,
125-
model_class: type[Model] | sa.orm.DeclarativeMeta = Model,
127+
model_class: type[Model] | sa_orm.DeclarativeMeta = Model,
126128
engine_options: dict[str, t.Any] | None = None,
127129
add_models_to_shell: bool = True,
128130
):
@@ -322,7 +324,7 @@ def init_app(self, app: Flask) -> None:
322324

323325
def _make_scoped_session(
324326
self, options: dict[str, t.Any]
325-
) -> sa.orm.scoped_session[Session]:
327+
) -> sa_orm.scoped_session[Session]:
326328
"""Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory
327329
from :meth:`_make_session_factory`. The result is available as :attr:`session`.
328330
@@ -345,11 +347,11 @@ def _make_scoped_session(
345347
"""
346348
scope = options.pop("scopefunc", _app_ctx_id)
347349
factory = self._make_session_factory(options)
348-
return sa.orm.scoped_session(factory, scope)
350+
return sa_orm.scoped_session(factory, scope)
349351

350352
def _make_session_factory(
351353
self, options: dict[str, t.Any]
352-
) -> sa.orm.sessionmaker[Session]:
354+
) -> sa_orm.sessionmaker[Session]:
353355
"""Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by
354356
:meth:`_make_scoped_session`.
355357
@@ -372,7 +374,7 @@ def _make_session_factory(
372374
"""
373375
options.setdefault("class_", Session)
374376
options.setdefault("query_cls", self.Query)
375-
return sa.orm.sessionmaker(db=self, **options)
377+
return sa_orm.sessionmaker(db=self, **options)
376378

377379
def _teardown_session(self, exc: BaseException | None) -> None:
378380
"""Remove the current session at the end of the request.
@@ -437,7 +439,7 @@ def __new__(
437439
return Table
438440

439441
def _make_declarative_base(
440-
self, model: type[Model] | sa.orm.DeclarativeMeta
442+
self, model: type[Model] | sa_orm.DeclarativeMeta
441443
) -> type[t.Any]:
442444
"""Create a SQLAlchemy declarative model class. The result is available as
443445
:attr:`Model`.
@@ -458,9 +460,9 @@ def _make_declarative_base(
458460
.. versionchanged:: 2.3
459461
``model`` can be an already created declarative model class.
460462
"""
461-
if not isinstance(model, sa.orm.DeclarativeMeta):
463+
if not isinstance(model, sa_orm.DeclarativeMeta):
462464
metadata = self._make_metadata(None)
463-
model = sa.orm.declarative_base(
465+
model = sa_orm.declarative_base(
464466
metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta
465467
)
466468

@@ -614,12 +616,12 @@ def engine(self) -> sa.engine.Engine:
614616

615617
def get_or_404(
616618
self,
617-
entity: type[t.Any],
619+
entity: type[_O],
618620
ident: t.Any,
619621
*,
620622
description: str | None = None,
621623
**kwargs: t.Any,
622-
) -> t.Any:
624+
) -> t.Optional[_O]:
623625
"""Like :meth:`session.get() <sqlalchemy.orm.Session.get>` but aborts with a
624626
``404 Not Found`` error instead of returning ``None``.
625627
@@ -672,7 +674,7 @@ def one_or_404(
672674
"""
673675
try:
674676
return self.session.execute(statement).scalar_one()
675-
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
677+
except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound):
676678
abort(404, description=description)
677679

678680
def paginate(
@@ -751,7 +753,7 @@ def _call_for_binds(
751753
if key is None:
752754
message = f"'SQLALCHEMY_DATABASE_URI' config is not set. {message}"
753755

754-
raise sa.exc.UnboundExecutionError(message) from None
756+
raise sa_exc.UnboundExecutionError(message) from None
755757

756758
metadata = self.metadatas[key]
757759
getattr(metadata, op_name)(bind=engine)
@@ -828,31 +830,31 @@ def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None:
828830

829831
def relationship(
830832
self, *args: t.Any, **kwargs: t.Any
831-
) -> sa.orm.RelationshipProperty[t.Any]:
833+
) -> sa_orm.RelationshipProperty[t.Any]:
832834
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
833835
:attr:`Query` class for dynamic relationships and backrefs.
834836
835837
.. versionchanged:: 3.0
836838
The :attr:`Query` class is set on ``backref``.
837839
"""
838840
self._set_rel_query(kwargs)
839-
return sa.orm.relationship(*args, **kwargs)
841+
return sa_orm.relationship(*args, **kwargs)
840842

841843
def dynamic_loader(
842844
self, argument: t.Any, **kwargs: t.Any
843-
) -> sa.orm.RelationshipProperty[t.Any]:
845+
) -> sa_orm.RelationshipProperty[t.Any]:
844846
"""A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's
845847
:attr:`Query` class for relationships and backrefs.
846848
847849
.. versionchanged:: 3.0
848850
The :attr:`Query` class is set on ``backref``.
849851
"""
850852
self._set_rel_query(kwargs)
851-
return sa.orm.dynamic_loader(argument, **kwargs)
853+
return sa_orm.dynamic_loader(argument, **kwargs)
852854

853855
def _relation(
854856
self, *args: t.Any, **kwargs: t.Any
855-
) -> sa.orm.RelationshipProperty[t.Any]:
857+
) -> sa_orm.RelationshipProperty[t.Any]:
856858
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
857859
:attr:`Query` class for dynamic relationships and backrefs.
858860
@@ -864,20 +866,20 @@ def _relation(
864866
The :attr:`Query` class is set on ``backref``.
865867
"""
866868
self._set_rel_query(kwargs)
867-
f = sa.orm.relation # type: ignore[attr-defined]
868-
return f(*args, **kwargs) # type: ignore[no-any-return]
869+
f = sa_orm.relationship
870+
return f(*args, **kwargs)
869871

870872
def __getattr__(self, name: str) -> t.Any:
871873
if name == "relation":
872874
return self._relation
873875

874876
if name == "event":
875-
return sa.event
877+
return sa_event
876878

877879
if name.startswith("_"):
878880
raise AttributeError(name)
879881

880-
for mod in (sa, sa.orm):
882+
for mod in (sa, sa_orm):
881883
if hasattr(mod, name):
882884
return getattr(mod, name)
883885

src/flask_sqlalchemy/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55

66
import sqlalchemy as sa
7-
import sqlalchemy.orm
7+
import sqlalchemy.orm as sa_orm
88

99
from .query import Query
1010

@@ -174,21 +174,21 @@ def should_set_tablename(cls: type) -> bool:
174174
joined-table inheritance. If no primary key is found, the name will be unset.
175175
"""
176176
if cls.__dict__.get("__abstract__", False) or not any(
177-
isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]
177+
isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]
178178
):
179179
return False
180180

181181
for base in cls.__mro__:
182182
if "__tablename__" not in base.__dict__:
183183
continue
184184

185-
if isinstance(base.__dict__["__tablename__"], sa.orm.declared_attr):
185+
if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr):
186186
return False
187187

188188
return not (
189189
base is cls
190190
or base.__dict__.get("__abstract__", False)
191-
or not isinstance(base, sa.orm.DeclarativeMeta)
191+
or not isinstance(base, sa_orm.DeclarativeMeta)
192192
)
193193

194194
return True
@@ -200,7 +200,7 @@ def camel_to_snake_case(name: str) -> str:
200200
return name.lower().lstrip("_")
201201

202202

203-
class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta):
203+
class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta):
204204
"""SQLAlchemy declarative metaclass that provides ``__bind_key__`` and
205205
``__tablename__`` support.
206206
"""

src/flask_sqlalchemy/pagination.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from math import ceil
55

66
import sqlalchemy as sa
7-
import sqlalchemy.orm
7+
import sqlalchemy.orm as sa_orm
88
from flask import abort
99
from flask import request
1010

@@ -336,7 +336,7 @@ def _query_items(self) -> list[t.Any]:
336336

337337
def _query_count(self) -> int:
338338
select = self._query_args["select"]
339-
sub = select.options(sa.orm.lazyload("*")).order_by(None).subquery()
339+
sub = select.options(sa_orm.lazyload("*")).order_by(None).subquery()
340340
session = self._query_args["session"]
341341
out = session.execute(sa.select(sa.func.count()).select_from(sub)).scalar()
342342
return out # type: ignore[no-any-return]

src/flask_sqlalchemy/query.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
import typing as t
44

5-
import sqlalchemy as sa
6-
import sqlalchemy.exc
7-
import sqlalchemy.orm
5+
import sqlalchemy.exc as sa_exc
6+
import sqlalchemy.orm as sa_orm
87
from flask import abort
98

109
from .pagination import Pagination
1110
from .pagination import QueryPagination
1211

1312

14-
class Query(sa.orm.Query): # type: ignore[type-arg]
13+
class Query(sa_orm.Query): # type: ignore[type-arg]
1514
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
1615
useful for querying in a web application.
1716
@@ -58,7 +57,7 @@ def one_or_404(self, description: str | None = None) -> t.Any:
5857
"""
5958
try:
6059
return self.one()
61-
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
60+
except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound):
6261
abort(404, description=description)
6362

6463
def paginate(

src/flask_sqlalchemy/record_queries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from time import perf_counter
77

88
import sqlalchemy as sa
9-
import sqlalchemy.event
9+
import sqlalchemy.event as sa_event
1010
from flask import current_app
1111
from flask import g
1212
from flask import has_app_context
@@ -72,8 +72,8 @@ def duration(self) -> float:
7272

7373

7474
def _listen(engine: sa.engine.Engine) -> None:
75-
sa.event.listen(engine, "before_cursor_execute", _record_start, named=True)
76-
sa.event.listen(engine, "after_cursor_execute", _record_end, named=True)
75+
sa_event.listen(engine, "before_cursor_execute", _record_start, named=True)
76+
sa_event.listen(engine, "after_cursor_execute", _record_end, named=True)
7777

7878

7979
def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None:

src/flask_sqlalchemy/session.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import typing as t
44

55
import sqlalchemy as sa
6-
import sqlalchemy.exc
7-
import sqlalchemy.orm
6+
import sqlalchemy.exc as sa_exc
7+
import sqlalchemy.orm as sa_orm
88
from flask.globals import app_ctx
99

1010
if t.TYPE_CHECKING:
1111
from .extension import SQLAlchemy
1212

1313

14-
class Session(sa.orm.Session):
14+
class Session(sa_orm.Session):
1515
"""A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to
1616
use based on the bind key associated with the metadata associated with the thing
1717
being queried.
@@ -55,9 +55,9 @@ def get_bind(
5555
if mapper is not None:
5656
try:
5757
mapper = sa.inspect(mapper)
58-
except sa.exc.NoInspectionAvailable as e:
58+
except sa_exc.NoInspectionAvailable as e:
5959
if isinstance(mapper, type):
60-
raise sa.orm.exc.UnmappedClassError(mapper) from e
60+
raise sa_orm.exc.UnmappedClassError(mapper) from e
6161

6262
raise
6363

@@ -88,7 +88,7 @@ def _clause_to_engine(
8888
key = clause.metadata.info["bind_key"]
8989

9090
if key not in engines:
91-
raise sa.exc.UnboundExecutionError(
91+
raise sa_exc.UnboundExecutionError(
9292
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
9393
)
9494

src/flask_sqlalchemy/track_modifications.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import typing as t
44

55
import sqlalchemy as sa
6-
import sqlalchemy.event
7-
import sqlalchemy.orm
6+
import sqlalchemy.event as sa_event
7+
import sqlalchemy.orm as sa_orm
88
from flask import current_app
99
from flask import has_app_context
1010
from flask.signals import Namespace # type: ignore[attr-defined]
@@ -29,12 +29,12 @@
2929
"""
3030

3131

32-
def _listen(session: sa.orm.scoped_session[Session]) -> None:
33-
sa.event.listen(session, "before_flush", _record_ops, named=True)
34-
sa.event.listen(session, "before_commit", _record_ops, named=True)
35-
sa.event.listen(session, "before_commit", _before_commit)
36-
sa.event.listen(session, "after_commit", _after_commit)
37-
sa.event.listen(session, "after_rollback", _after_rollback)
32+
def _listen(session: sa_orm.scoped_session[Session]) -> None:
33+
sa_event.listen(session, "before_flush", _record_ops, named=True)
34+
sa_event.listen(session, "before_commit", _record_ops, named=True)
35+
sa_event.listen(session, "before_commit", _before_commit)
36+
sa_event.listen(session, "after_commit", _after_commit)
37+
sa_event.listen(session, "after_rollback", _after_rollback)
3838

3939

4040
def _record_ops(session: Session, **kwargs: t.Any) -> None:

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def db(app: Flask) -> SQLAlchemy:
3131

3232

3333
@pytest.fixture
34-
def Todo(app: Flask, db: SQLAlchemy) -> t.Any:
34+
def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]:
3535
class Todo(db.Model):
3636
id = sa.Column(sa.Integer, primary_key=True)
3737
title = sa.Column(sa.String)

0 commit comments

Comments
 (0)