From e6dc0a7ca86fa17f9826b22b12337ac33c1892ff Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 21:20:29 +0300 Subject: [PATCH 01/14] task-local sessions --- .pre-commit-config.yaml | 2 +- fastapi_async_sqlalchemy/middleware.py | 108 +++++++++++++++++++------ requirements.txt | 6 +- 3 files changed, 89 insertions(+), 27 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7723dcc..a21adc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - --max-line-length=100 - --ignore=E203, E501, W503 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 + rev: v1.12.0 hooks: - id: mypy additional_dependencies: diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index bb6024d..e059b7c 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,27 +1,21 @@ +import asyncio from contextvars import ContextVar from typing import Dict, Optional, Union from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.types import ASGIApp from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError -try: - from sqlalchemy.ext.asyncio import async_sessionmaker -except ImportError: - from sqlalchemy.orm import sessionmaker as async_sessionmaker - def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None - # Usage of context vars inside closures is not recommended, since they are not properly - # garbage collected, but in our use case context var is created on program startup and - # is used throughout the whole its lifecycle. _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) + _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( @@ -29,8 +23,8 @@ def __init__( app: ASGIApp, db_url: Optional[Union[str, URL]] = None, custom_engine: Optional[Engine] = None, - engine_args: Dict = None, - session_args: Dict = None, + engine_args: Optional[Dict] = None, + session_args: Optional[Dict] = None, commit_on_exit: bool = False, ): super().__init__(app) @@ -57,28 +51,96 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): class DBSessionMeta(type): @property def session(self) -> AsyncSession: - """Return an instance of Session local to the current async context.""" + """Возвращает текущую сессию из контекста.""" if _Session is None: raise SessionNotInitialisedError - session = _session.get() - if session is None: + current_session = _session.get() + if current_session is None: raise MissingSessionError - return session + multi_sessions = _multi_sessions_ctx.get() + if multi_sessions: + # Если multi_sessions=True, используем Task-локальные сессии + task = asyncio.current_task() + if not hasattr(task, "_db_session"): + task._db_session = _Session() + + def cleanup(future): + session = getattr(task, "_db_session", None) + if session: + + async def do_cleanup(): + try: + if future.exception(): + await session.rollback() + else: + await session.commit() + finally: + await session.close() + + asyncio.create_task(do_cleanup()) + + task.add_done_callback(cleanup) + return task._db_session + else: + return current_session + + def __call__( + cls, + session_args: Optional[Dict] = None, + commit_on_exit: bool = False, + multi_sessions: bool = False, + ): + return cls._context_manager( + session_args=session_args, + commit_on_exit=commit_on_exit, + multi_sessions=multi_sessions, + ) + + def _context_manager( + cls, + session_args: Dict = None, + commit_on_exit: bool = False, + multi_sessions: bool = False, + ): + return DBSessionContextManager( + session_args=session_args, + commit_on_exit=commit_on_exit, + multi_sessions=multi_sessions, + ) class DBSession(metaclass=DBSessionMeta): - def __init__(self, session_args: Dict = None, commit_on_exit: bool = False): - self.token = None + pass + + class DBSessionContextManager: + def __init__( + self, + session_args: Dict = None, + commit_on_exit: bool = False, + multi_sessions: bool = False, + ): self.session_args = session_args or {} self.commit_on_exit = commit_on_exit + self.multi_sessions = multi_sessions + self.token = None + self.multi_sessions_token = None + self._session = None async def __aenter__(self): - if not isinstance(_Session, async_sessionmaker): + if _Session is None: raise SessionNotInitialisedError - self.token = _session.set(_Session(**self.session_args)) # type: ignore - return type(self) + if self.multi_sessions: + self.multi_sessions_token = _multi_sessions_ctx.set(True) + + self._session = _Session(**self.session_args) + self.token = _session.set(self._session) + return self + + @property + def session(self): + return self._session async def __aexit__(self, exc_type, exc_value, traceback): session = _session.get() @@ -86,13 +148,13 @@ async def __aexit__(self, exc_type, exc_value, traceback): try: if exc_type is not None: await session.rollback() - elif ( - self.commit_on_exit - ): # Note: Changed this to elif to avoid commit after rollback + elif self.commit_on_exit: await session.commit() finally: await session.close() _session.reset(self.token) + if self.multi_sessions_token is not None: + _multi_sessions_ctx.reset(self.multi_sessions_token) return SQLAlchemyMiddleware, DBSession diff --git a/requirements.txt b/requirements.txt index 0279928..d3232d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,13 +17,13 @@ packaging>=22.0 pathspec>=0.9.0 pluggy==0.13.0 pycodestyle==2.5.0 -pydantic==1.10.13 +pydantic==1.10.18 pyflakes==2.1.1 pyparsing==2.4.2 pytest==7.2.0 pytest-cov==2.11.1 PyYAML>=5.4 -regex==2020.2.20 +regex>=2020.2.20 requests>=2.22.0 httpx>=0.20.0 six==1.12.0 @@ -39,4 +39,4 @@ wcwidth==0.1.7 zipp==3.1.0 black==24.4.2 pytest-asyncio==0.21.0 -greenlet==2.0.2 +greenlet==3.1.1 From 5af8f493f336e08672ac23f766b3928140c8261f Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 21:20:59 +0300 Subject: [PATCH 02/14] task-local sessions --- fastapi_async_sqlalchemy/middleware.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index e059b7c..229c5fa 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -51,7 +51,6 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): class DBSessionMeta(type): @property def session(self) -> AsyncSession: - """Возвращает текущую сессию из контекста.""" if _Session is None: raise SessionNotInitialisedError @@ -61,7 +60,6 @@ def session(self) -> AsyncSession: multi_sessions = _multi_sessions_ctx.get() if multi_sessions: - # Если multi_sessions=True, используем Task-локальные сессии task = asyncio.current_task() if not hasattr(task, "_db_session"): task._db_session = _Session() From 7aa6decca560dd0abe41557b44aec8319212c7e0 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 21:28:16 +0300 Subject: [PATCH 03/14] commit_on_exit fix --- fastapi_async_sqlalchemy/middleware.py | 70 +++++++++----------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 229c5fa..195d0ac 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -11,11 +11,17 @@ from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError +try: + from sqlalchemy.ext.asyncio import async_sessionmaker +except ImportError: + from sqlalchemy.orm import sessionmaker as async_sessionmaker + def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) + _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( @@ -23,8 +29,8 @@ def __init__( app: ASGIApp, db_url: Optional[Union[str, URL]] = None, custom_engine: Optional[Engine] = None, - engine_args: Optional[Dict] = None, - session_args: Optional[Dict] = None, + engine_args: Dict = None, + session_args: Dict = None, commit_on_exit: bool = False, ): super().__init__(app) @@ -54,12 +60,9 @@ def session(self) -> AsyncSession: if _Session is None: raise SessionNotInitialisedError - current_session = _session.get() - if current_session is None: - raise MissingSessionError - multi_sessions = _multi_sessions_ctx.get() if multi_sessions: + commit_on_exit = _commit_on_exit_ctx.get() task = asyncio.current_task() if not hasattr(task, "_db_session"): task._db_session = _Session() @@ -73,7 +76,8 @@ async def do_cleanup(): if future.exception(): await session.rollback() else: - await session.commit() + if commit_on_exit: + await session.commit() finally: await session.close() @@ -82,67 +86,38 @@ async def do_cleanup(): task.add_done_callback(cleanup) return task._db_session else: - return current_session - - def __call__( - cls, - session_args: Optional[Dict] = None, - commit_on_exit: bool = False, - multi_sessions: bool = False, - ): - return cls._context_manager( - session_args=session_args, - commit_on_exit=commit_on_exit, - multi_sessions=multi_sessions, - ) - - def _context_manager( - cls, - session_args: Dict = None, - commit_on_exit: bool = False, - multi_sessions: bool = False, - ): - return DBSessionContextManager( - session_args=session_args, - commit_on_exit=commit_on_exit, - multi_sessions=multi_sessions, - ) + session = _session.get() + if session is None: + raise MissingSessionError + return session class DBSession(metaclass=DBSessionMeta): - pass - - class DBSessionContextManager: def __init__( self, session_args: Dict = None, commit_on_exit: bool = False, multi_sessions: bool = False, ): + self.token = None + self.multi_sessions_token = None + self.commit_on_exit_token = None self.session_args = session_args or {} self.commit_on_exit = commit_on_exit self.multi_sessions = multi_sessions - self.token = None - self.multi_sessions_token = None - self._session = None async def __aenter__(self): - if _Session is None: + if not isinstance(_Session, async_sessionmaker): raise SessionNotInitialisedError if self.multi_sessions: self.multi_sessions_token = _multi_sessions_ctx.set(True) + self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit) - self._session = _Session(**self.session_args) - self.token = _session.set(self._session) - return self - - @property - def session(self): - return self._session + self.token = _session.set(_Session(**self.session_args)) + return type(self) async def __aexit__(self, exc_type, exc_value, traceback): session = _session.get() - try: if exc_type is not None: await session.rollback() @@ -153,6 +128,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): _session.reset(self.token) if self.multi_sessions_token is not None: _multi_sessions_ctx.reset(self.multi_sessions_token) + _commit_on_exit_ctx.reset(self.commit_on_exit_token) return SQLAlchemyMiddleware, DBSession From d95a2b087bd904d52697dbace3a4a31b3209cf49 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:08:20 +0300 Subject: [PATCH 04/14] added test_multi_sessions --- tests/test_session.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_session.py b/tests/test_session.py index 06163ac..c7e8666 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,7 @@ +import asyncio + import pytest +from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from starlette.middleware.base import BaseHTTPMiddleware @@ -148,3 +151,24 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_ session_args = {"expire_on_commit": False} async with db(session_args=session_args): db.session + + +@pytest.mark.asyncio +async def test_multi_sessions(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + async with db(multi_sessions=True): + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + res = await asyncio.gather(*tasks) + assert len(res) == 6 From cb29f3e67f26e0d4d2a8c2b56df0aadf59866bb4 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:14:11 +0300 Subject: [PATCH 05/14] added test_multi_sessions --- README.md | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 84546be..70f1c6a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +from pytest_cov.embed import multiprocessing_finishfrom tests.test_session import db_url + # SQLAlchemy FastAPI middleware [![ci](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB) @@ -127,9 +129,10 @@ app.add_middleware( routes.py ```python +import asyncio + from fastapi import APIRouter -from sqlalchemy import column -from sqlalchemy import table +from sqlalchemy import column, table, text from databases import first_db, second_db @@ -147,4 +150,22 @@ async def get_files_from_first_db(): async def get_files_from_second_db(): result = await second_db.session.execute(foo.select()) return result.fetchall() + + +@router.get("/concurrent-queries") +async def parallel_select(): + async with first_db(multi_sessions=True): + async def execute_query(query): + return await first_db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + await asyncio.gather(*tasks) ``` From cb081f1e7b98f5758a13031116cb56d512582946 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:27:48 +0300 Subject: [PATCH 06/14] fixes mypy --- .pre-commit-config.yaml | 2 +- fastapi_async_sqlalchemy/middleware.py | 9 +++++---- tests/test_session.py | 1 + 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a21adc0..7723dcc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - --max-line-length=100 - --ignore=E203, E501, W503 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.12.0 + rev: v0.982 hooks: - id: mypy additional_dependencies: diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 195d0ac..6893e28 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,4 +1,5 @@ import asyncio +from asyncio import Task from contextvars import ContextVar from typing import Dict, Optional, Union @@ -12,7 +13,7 @@ from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError try: - from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811 except ImportError: from sqlalchemy.orm import sessionmaker as async_sessionmaker @@ -63,9 +64,9 @@ def session(self) -> AsyncSession: multi_sessions = _multi_sessions_ctx.get() if multi_sessions: commit_on_exit = _commit_on_exit_ctx.get() - task = asyncio.current_task() + task: Task = asyncio.current_task() # type: ignore if not hasattr(task, "_db_session"): - task._db_session = _Session() + task._db_session = _Session() # type: ignore def cleanup(future): session = getattr(task, "_db_session", None) @@ -84,7 +85,7 @@ async def do_cleanup(): asyncio.create_task(do_cleanup()) task.add_done_callback(cleanup) - return task._db_session + return task._db_session # type: ignore else: session = _session.get() if session is None: diff --git a/tests/test_session.py b/tests/test_session.py index c7e8666..82f5dc9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -158,6 +158,7 @@ async def test_multi_sessions(app, db, SQLAlchemyMiddleware): app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) async with db(multi_sessions=True): + async def execute_query(query): return await db.session.execute(text(query)) From 5c9f01fa0c2ae247c112ca98dc7a24fd350bccbe Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:28:41 +0300 Subject: [PATCH 07/14] fixes mypy --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 70f1c6a..6d1e722 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -from pytest_cov.embed import multiprocessing_finishfrom tests.test_session import db_url - # SQLAlchemy FastAPI middleware [![ci](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB) From aafe26b517e4a3e53428574b182ac8f27df276e6 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:39:11 +0300 Subject: [PATCH 08/14] fixes mypy --- fastapi_async_sqlalchemy/middleware.py | 27 ++++++++++++++++++++++++++ requirements.txt | 4 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 6893e28..f70a49e 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -23,6 +23,9 @@ def create_middleware_and_session_proxy(): _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) + # Usage of context vars inside closures is not recommended, since they are not properly + # garbage collected, but in our use case context var is created on program startup and + # is used throughout the whole its lifecycle. class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( @@ -58,11 +61,35 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): class DBSessionMeta(type): @property def session(self) -> AsyncSession: + """Return an instance of Session local to the current async context.""" if _Session is None: raise SessionNotInitialisedError multi_sessions = _multi_sessions_ctx.get() if multi_sessions: + """If multi_sessions is True, we are in a context where multiple sessions are allowed. + In this case, we need to create a new session for each task. + We also need to commit the session on exit if commit_on_exit is True. + This is useful when we need to run multiple queries in parallel. + For example, when we need to run multiple queries in parallel in a route handler. + Example: + ```python + async with db(multi_sessions=True): + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + await asyncio.gather(*tasks) + ``` + """ commit_on_exit = _commit_on_exit_ctx.get() task: Task = asyncio.current_task() # type: ignore if not hasattr(task, "_db_session"): diff --git a/requirements.txt b/requirements.txt index d3232d8..e3a0644 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ coverage>=5.2.1 entrypoints==0.3 fastapi==0.90.0 # pyup: ignore flake8==3.7.9 -idna==2.8 +idna==3.7 importlib-metadata==1.5.0 isort==4.3.21 mccabe==0.6.1 @@ -36,7 +36,7 @@ toml>=0.10.1 typed-ast>=1.4.2 urllib3>=1.25.9 wcwidth==0.1.7 -zipp==3.1.0 +zipp==3.19.1 black==24.4.2 pytest-asyncio==0.21.0 greenlet==3.1.1 From a657d8f2067f6e1b999bd8522ed1731ab9122505 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:57:27 +0300 Subject: [PATCH 09/14] fixes mypy --- fastapi_async_sqlalchemy/middleware.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index f70a49e..06f3268 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -67,8 +67,7 @@ def session(self) -> AsyncSession: multi_sessions = _multi_sessions_ctx.get() if multi_sessions: - """If multi_sessions is True, we are in a context where multiple sessions are allowed. - In this case, we need to create a new session for each task. + """ In this case, we need to create a new session for each task. We also need to commit the session on exit if commit_on_exit is True. This is useful when we need to run multiple queries in parallel. For example, when we need to run multiple queries in parallel in a route handler. From 65ed17a3ce8f758b728d3e1cb00e7747815671cb Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:58:51 +0300 Subject: [PATCH 10/14] fixes mypy --- fastapi_async_sqlalchemy/middleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 06f3268..e3b1563 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -67,7 +67,7 @@ def session(self) -> AsyncSession: multi_sessions = _multi_sessions_ctx.get() if multi_sessions: - """ In this case, we need to create a new session for each task. + """In this case, we need to create a new session for each task. We also need to commit the session on exit if commit_on_exit is True. This is useful when we need to run multiple queries in parallel. For example, when we need to run multiple queries in parallel in a route handler. From 83d7bef8b99c83493444aa24bfc8f64094186220 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 23:45:44 +0300 Subject: [PATCH 11/14] fixes mypy --- fastapi_async_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index ab2c3b7..21c6edc 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.6.1" +__version__ = "0.7.0.dev1" From 0c74aaf2fc9808374e5cc148a96a32f38f3315ac Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Fri, 18 Oct 2024 13:04:31 +0300 Subject: [PATCH 12/14] WIP: multi_sessions --- fastapi_async_sqlalchemy/__init__.py | 2 +- fastapi_async_sqlalchemy/middleware.py | 75 +++++++++++++------------- tests/test_session.py | 25 +++++++++ 3 files changed, 64 insertions(+), 38 deletions(-) diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 21c6edc..2821653 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.7.0.dev1" +__version__ = "0.7.0.dev2" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index e3b1563..13cb359 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,5 +1,4 @@ import asyncio -from asyncio import Task from contextvars import ContextVar from typing import Dict, Optional, Union @@ -22,6 +21,9 @@ def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) + _task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar( + "_task_session_ctx", default=None + ) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) # Usage of context vars inside closures is not recommended, since they are not properly # garbage collected, but in our use case context var is created on program startup and @@ -90,28 +92,26 @@ async def execute_query(query): ``` """ commit_on_exit = _commit_on_exit_ctx.get() - task: Task = asyncio.current_task() # type: ignore - if not hasattr(task, "_db_session"): - task._db_session = _Session() # type: ignore - - def cleanup(future): - session = getattr(task, "_db_session", None) - if session: - - async def do_cleanup(): - try: - if future.exception(): - await session.rollback() - else: - if commit_on_exit: - await session.commit() - finally: - await session.close() - - asyncio.create_task(do_cleanup()) - - task.add_done_callback(cleanup) - return task._db_session # type: ignore + session = _task_session_ctx.get() + if session is None: + session = _Session() + _task_session_ctx.set(session) + + async def cleanup(): + try: + if commit_on_exit: + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + _task_session_ctx.set(None) + + task = asyncio.current_task() + if task is not None: + task.add_done_callback(lambda t: asyncio.create_task(cleanup())) + return session else: session = _session.get() if session is None: @@ -139,23 +139,24 @@ async def __aenter__(self): if self.multi_sessions: self.multi_sessions_token = _multi_sessions_ctx.set(True) self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit) - - self.token = _session.set(_Session(**self.session_args)) + else: + self.token = _session.set(_Session(**self.session_args)) return type(self) async def __aexit__(self, exc_type, exc_value, traceback): - session = _session.get() - try: - if exc_type is not None: - await session.rollback() - elif self.commit_on_exit: - await session.commit() - finally: - await session.close() - _session.reset(self.token) - if self.multi_sessions_token is not None: - _multi_sessions_ctx.reset(self.multi_sessions_token) - _commit_on_exit_ctx.reset(self.commit_on_exit_token) + if self.multi_sessions: + _multi_sessions_ctx.reset(self.multi_sessions_token) + _commit_on_exit_ctx.reset(self.commit_on_exit_token) + else: + session = _session.get() + try: + if exc_type is not None: + await session.rollback() + elif self.commit_on_exit: + await session.commit() + finally: + await session.close() + _session.reset(self.token) return SQLAlchemyMiddleware, DBSession diff --git a/tests/test_session.py b/tests/test_session.py index 82f5dc9..9400fea 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -173,3 +173,28 @@ async def execute_query(query): res = await asyncio.gather(*tasks) assert len(res) == 6 + + +@pytest.mark.asyncio +async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + async with db(multi_sessions=True, commit_on_exit=True): + await db.session.execute( + text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)") + ) + + async def insert_data(value): + await db.session.execute( + text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value} + ) + await db.session.flush() + + tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)] + + result_ids = await asyncio.gather(*tasks) + assert len(result_ids) == 10 + + records = await db.session.execute(text("SELECT * FROM my_model")) + records = records.scalars().all() + assert len(records) == 10 From 955f538fdd908f3fdae601ac52d798c6eac1e937 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Sat, 19 Oct 2024 01:12:50 +0300 Subject: [PATCH 13/14] WIP: multi_sessions --- fastapi_async_sqlalchemy/__init__.py | 2 +- fastapi_async_sqlalchemy/middleware.py | 39 +++++++++++--------------- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 2821653..8afd39f 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.7.0.dev2" +__version__ = "0.7.0.dev3" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 13cb359..1171ede 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -21,9 +21,6 @@ def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) - _task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar( - "_task_session_ctx", default=None - ) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) # Usage of context vars inside closures is not recommended, since they are not properly # garbage collected, but in our use case context var is created on program startup and @@ -92,25 +89,22 @@ async def execute_query(query): ``` """ commit_on_exit = _commit_on_exit_ctx.get() - session = _task_session_ctx.get() - if session is None: - session = _Session() - _task_session_ctx.set(session) - - async def cleanup(): - try: - if commit_on_exit: - await session.commit() - except Exception: - await session.rollback() - raise - finally: - await session.close() - _task_session_ctx.set(None) - - task = asyncio.current_task() - if task is not None: - task.add_done_callback(lambda t: asyncio.create_task(cleanup())) + # Always create a new session for each access when multi_sessions=True + session = _Session() + + async def cleanup(): + try: + if commit_on_exit: + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + task = asyncio.current_task() + if task is not None: + task.add_done_callback(lambda t: asyncio.create_task(cleanup())) return session else: session = _session.get() @@ -126,7 +120,6 @@ def __init__( multi_sessions: bool = False, ): self.token = None - self.multi_sessions_token = None self.commit_on_exit_token = None self.session_args = session_args or {} self.commit_on_exit = commit_on_exit From a76b7bb67ea6e1fe90f342bbbe002d58ec1f876f Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Fri, 17 Jan 2025 10:47:00 +0200 Subject: [PATCH 14/14] fix import create_middleware_and_session_proxy --- README.md | 2 +- fastapi_async_sqlalchemy/__init__.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6d1e722..32898a2 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine. ### Install ```bash - pip install fastapi-async-sqlalchemy + pip install fastapi-async-sqlalchemy ``` diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 8afd39f..963466d 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -1,5 +1,9 @@ -from fastapi_async_sqlalchemy.middleware import SQLAlchemyMiddleware, db +from fastapi_async_sqlalchemy.middleware import ( + SQLAlchemyMiddleware, + db, + create_middleware_and_session_proxy, +) -__all__ = ["db", "SQLAlchemyMiddleware"] +__all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"] -__version__ = "0.7.0.dev3" +__version__ = "0.7.0.dev4"