diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bd1026a..808c1c7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -13,26 +13,23 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        build: [linux_3.8, windows_3.8, mac_3.8, linux_3.7]
+        build: [linux_3.9, windows_3.9, mac_3.9]
         include:
-          - build: linux_3.8
+          - build: linux_3.9
             os: ubuntu-latest
-            python: 3.8
-          - build: windows_3.8
+            python: 3.9
+          - build: windows_3.9
             os: windows-latest
-            python: 3.8
-          - build: mac_3.8
+            python: 3.9
+          - build: mac_3.9
             os: macos-latest
-            python: 3.8
-          - build: linux_3.7
-            os: ubuntu-latest
-            python: 3.7
+            python: 3.9

     steps:
       - name: Checkout repository
         uses: actions/checkout@v2

       - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python }}
@@ -43,16 +40,16 @@ jobs:

-      # test all the builds apart from linux_3.8...
+      # test all the builds apart from linux_3.9...
       - name: Test with pytest
-        if: matrix.build != 'linux_3.8'
+        if: matrix.build != 'linux_3.9'
         run: pytest

-      # only do the test coverage for linux_3.8
+      # only do the test coverage for linux_3.9
       - name: Produce coverage report
-        if: matrix.build == 'linux_3.8'
+        if: matrix.build == 'linux_3.9'
         run: pytest --cov=fastapi_async_sqlalchemy --cov-report=xml

       - name: Upload coverage report
-        if: matrix.build == 'linux_3.8'
+        if: matrix.build == 'linux_3.9'
         uses: codecov/codecov-action@v1
         with:
           file: ./coverage.xml
@@ -65,9 +62,9 @@ jobs:
         uses: actions/checkout@v2

       - name: Set up Python
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.9

       - name: Install dependencies
         run: pip install flake8
@@ -83,9 +80,9 @@ jobs:
         uses: actions/checkout@v2

       - name: Set up Python
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.9

       - name: Install dependencies
         # isort needs all of the packages to be installed so it can
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 1a03a7b..bdaab28 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -1,11 +1,19 @@
 # This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.

 name: Upload Python Package

 on:
   release:
-    types: [created]
+    types: [published]
+
+permissions:
+  contents: read

 jobs:
   deploy:
@@ -13,19 +21,19 @@ jobs:
     runs-on: ubuntu-latest

     steps:
-    - uses: actions/checkout@v2
+    - uses: actions/checkout@v3
     - name: Set up Python
-      uses: actions/setup-python@v2
+      uses: actions/setup-python@v3
       with:
         python-version: '3.x'
     - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
-        pip install setuptools wheel twine
-    - name: Build and publish
-      env:
-        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-      run: |
-        python setup.py sdist bdist_wheel
-        twine upload dist/*
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/README.md b/README.md
index 56d3a58..32898a2 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 [![codecov](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy/branch/main/graph/badge.svg?token=F4NJ34WKPY)](https://codecov.io/gh/h0rn3t/fastapi-async-sqlalchemy)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![pip](https://img.shields.io/pypi/v/fastapi_async_sqlalchemy?color=blue)](https://pypi.org/project/fastapi-async-sqlalchemy/)
-[![Downloads](https://pepy.tech/badge/fastapi-async-sqlalchemy)](https://pepy.tech/project/fastapi-async-sqlalchemy)
+[![Downloads](https://static.pepy.tech/badge/fastapi-async-sqlalchemy)](https://pepy.tech/project/fastapi-async-sqlalchemy)
 [![Updates](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/shield.svg)](https://pyup.io/repos/github/h0rn3t/fastapi-async-sqlalchemy/)

 ### Description
@@ -15,11 +15,11 @@ Provides SQLAlchemy middleware for FastAPI using AsyncSession and async engine.
 ### Install

 ```bash
-  pip install fastapi-async-sqlalchemy
+pip install fastapi-async-sqlalchemy
 ```

-### Important !!!
-If you use ```sqlmodel``` install ```sqlalchemy<=1.4.41```
+
+It also works with ```sqlmodel```

 ### Examples

@@ -127,9 +127,10 @@ app.add_middleware(
 routes.py

 ```python
+import asyncio
+
 from fastapi import APIRouter
-from sqlalchemy import column
-from sqlalchemy import table
+from sqlalchemy import column, table, text

 from databases import first_db, second_db

@@ -147,4 +148,22 @@ async def get_files_from_first_db():
 async def get_files_from_second_db():
     result = await second_db.session.execute(foo.select())
     return result.fetchall()
-```
\ No newline at end of file
+
+
+@router.get("/concurrent-queries")
+async def parallel_select():
+    async with first_db(multi_sessions=True):
+        async def execute_query(query):
+            return await first_db.session.execute(text(query))
+
+        tasks = [
+            asyncio.create_task(execute_query("SELECT 1")),
+            asyncio.create_task(execute_query("SELECT 2")),
+            asyncio.create_task(execute_query("SELECT 3")),
+            asyncio.create_task(execute_query("SELECT 4")),
+            asyncio.create_task(execute_query("SELECT 5")),
+            asyncio.create_task(execute_query("SELECT 6")),
+        ]
+
+        await asyncio.gather(*tasks)
+```
diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py
index 0a44930..963466d 100644
--- a/fastapi_async_sqlalchemy/__init__.py
+++ b/fastapi_async_sqlalchemy/__init__.py
@@ -1,5 +1,9 @@
-from fastapi_async_sqlalchemy.middleware import SQLAlchemyMiddleware, db
+from fastapi_async_sqlalchemy.middleware import (
+    SQLAlchemyMiddleware,
+    db,
+    create_middleware_and_session_proxy,
+)

-__all__ = ["db", "SQLAlchemyMiddleware"]
+__all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"]

-__version__ = "0.6.0"
+__version__ = "0.7.0.dev4"
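For context on the newly exported `create_middleware_and_session_proxy`: each call returns an independent `(SQLAlchemyMiddleware, DBSession)` pair, which is how the `first_db` / `second_db` proxies imported in the routes.py example above are meant to be produced. A minimal sketch of that wiring, assuming a `databases.py` module and placeholder database URLs (the module name and URLs are illustrative, not taken from this diff):

```python
# databases.py -- illustrative module name; the URLs below are placeholders.
from fastapi import FastAPI
from fastapi_async_sqlalchemy import create_middleware_and_session_proxy

# Each call builds an independent middleware class and session proxy pair,
# backed by its own engine and its own context-local session.
FirstSQLAlchemyMiddleware, first_db = create_middleware_and_session_proxy()
SecondSQLAlchemyMiddleware, second_db = create_middleware_and_session_proxy()

app = FastAPI()
app.add_middleware(
    FirstSQLAlchemyMiddleware,
    db_url="postgresql+asyncpg://user:password@localhost/first_db",  # placeholder
)
app.add_middleware(
    SecondSQLAlchemyMiddleware,
    db_url="postgresql+asyncpg://user:password@localhost/second_db",  # placeholder
)
```

Inside request handlers the two proxies are then used exactly like the global `db`, e.g. `await first_db.session.execute(...)`, as in the routes.py example above.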
diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py
index 1a407fa..1171ede 100644
--- a/fastapi_async_sqlalchemy/middleware.py
+++ b/fastapi_async_sqlalchemy/middleware.py
@@ -1,9 +1,10 @@
+import asyncio
 from contextvars import ContextVar
 from typing import Dict, Optional, Union

 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.url import URL
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import Request
 from starlette.types import ASGIApp
@@ -11,17 +12,19 @@ from starlette.types import ASGIApp
 from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError

 try:
-    from sqlalchemy.ext.asyncio import async_sessionmaker
+    from sqlalchemy.ext.asyncio import async_sessionmaker  # noqa: F811
 except ImportError:
     from sqlalchemy.orm import sessionmaker as async_sessionmaker


 def create_middleware_and_session_proxy():
     _Session: Optional[async_sessionmaker] = None
+    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
+    _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
+    _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
     # Usage of context vars inside closures is not recommended, since they are not properly
     # garbage collected, but in our use case context var is created on program startup and
     # is used throughout the whole its lifecycle.
-    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

     class SQLAlchemyMiddleware(BaseHTTPMiddleware):
         def __init__(
@@ -51,7 +54,7 @@ def __init__(
             )

         async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
-            async with db(commit_on_exit=self.commit_on_exit):
+            async with DBSession(commit_on_exit=self.commit_on_exit):
                 return await call_next(request)

     class DBSessionMeta(type):
@@ -61,38 +64,92 @@ def session(self) -> AsyncSession:
             if _Session is None:
                 raise SessionNotInitialisedError

-            session = _session.get()
-            if session is None:
-                raise MissingSessionError
-
-            return session
+            multi_sessions = _multi_sessions_ctx.get()
+            if multi_sessions:
+                """In this case, we need to create a new session for each task.
+                We also need to commit the session on exit if commit_on_exit is True.
+                This is useful when we need to run multiple queries in parallel.
+                For example, when we need to run multiple queries in parallel in a route handler.
+                Example:
+                ```python
+                async with db(multi_sessions=True):
+                    async def execute_query(query):
+                        return await db.session.execute(text(query))
+
+                    tasks = [
+                        asyncio.create_task(execute_query("SELECT 1")),
+                        asyncio.create_task(execute_query("SELECT 2")),
+                        asyncio.create_task(execute_query("SELECT 3")),
+                        asyncio.create_task(execute_query("SELECT 4")),
+                        asyncio.create_task(execute_query("SELECT 5")),
+                        asyncio.create_task(execute_query("SELECT 6")),
+                    ]
+
+                    await asyncio.gather(*tasks)
+                ```
+                """
+                commit_on_exit = _commit_on_exit_ctx.get()
+                # Always create a new session for each access when multi_sessions=True
+                session = _Session()
+
+                async def cleanup():
+                    try:
+                        if commit_on_exit:
+                            await session.commit()
+                    except Exception:
+                        await session.rollback()
+                        raise
+                    finally:
+                        await session.close()
+
+                task = asyncio.current_task()
+                if task is not None:
+                    task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
+                return session
+            else:
+                session = _session.get()
+                if session is None:
+                    raise MissingSessionError
+                return session

     class DBSession(metaclass=DBSessionMeta):
-        def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
+        def __init__(
+            self,
+            session_args: Dict = None,
+            commit_on_exit: bool = False,
+            multi_sessions: bool = False,
+        ):
             self.token = None
+            self.commit_on_exit_token = None
             self.session_args = session_args or {}
             self.commit_on_exit = commit_on_exit
+            self.multi_sessions = multi_sessions

         async def __aenter__(self):
             if not isinstance(_Session, async_sessionmaker):
                 raise SessionNotInitialisedError

-            self.token = _session.set(_Session(**self.session_args))  # type: ignore
+            if self.multi_sessions:
+                self.multi_sessions_token = _multi_sessions_ctx.set(True)
+                self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
+            else:
+                self.token = _session.set(_Session(**self.session_args))
             return type(self)

         async def __aexit__(self, exc_type, exc_value, traceback):
-            session = _session.get()
-
-            try:
-                if exc_type is not None:
-                    await session.rollback()
-                elif (
-                    self.commit_on_exit
-                ):  # Note: Changed this to elif to avoid commit after rollback
-                    await session.commit()
-            finally:
-                await session.close()
-                _session.reset(self.token)
+            if self.multi_sessions:
+                _multi_sessions_ctx.reset(self.multi_sessions_token)
+                _commit_on_exit_ctx.reset(self.commit_on_exit_token)
+            else:
+                session = _session.get()
+                try:
+                    if exc_type is not None:
+                        await session.rollback()
+                    elif self.commit_on_exit:
+                        await session.commit()
+                finally:
+                    await session.close()
+                    _session.reset(self.token)
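To summarise the control flow introduced above from the caller's side: entering `db(multi_sessions=True, ...)` only flips the two new context variables; every later `db.session` access then builds a fresh `AsyncSession`, and its commit (when `commit_on_exit` was requested) and close are scheduled when the task that used it finishes. A minimal sketch of that behaviour in a route handler; the route path and the in-memory SQLite URL are illustrative only:

```python
import asyncio

from fastapi import FastAPI
from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db
from sqlalchemy import text

app = FastAPI()
# Placeholder URL: any async driver accepted by create_async_engine works here.
app.add_middleware(SQLAlchemyMiddleware, db_url="sqlite+aiosqlite://")


@app.get("/fan-out")
async def fan_out():
    # multi_sessions=True: each db.session access below returns its own AsyncSession,
    # so the spawned tasks never share a single session object.
    # commit_on_exit=True is stored in a ContextVar and applied per task by the
    # cleanup callback that DBSessionMeta.session registers on the current task.
    async with db(multi_sessions=True, commit_on_exit=True):

        async def run(query: str):
            return await db.session.execute(text(query))

        results = await asyncio.gather(
            asyncio.create_task(run("SELECT 1")),
            asyncio.create_task(run("SELECT 2")),
            asyncio.create_task(run("SELECT 3")),
        )
    return {"values": [r.scalar() for r in results]}
```

In the default mode (no `multi_sessions`), exiting the context still rolls back on error, commits only when `commit_on_exit=True`, and always closes the session and resets the context variable, exactly as in the `else` branch of `__aexit__` above.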
+                    _session.reset(self.token)

     return SQLAlchemyMiddleware, DBSession
diff --git a/requirements.txt b/requirements.txt
index bb00aee..e3a0644 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,35 +8,35 @@ coverage>=5.2.1
 entrypoints==0.3
 fastapi==0.90.0  # pyup: ignore
 flake8==3.7.9
-idna==2.8
+idna==3.7
 importlib-metadata==1.5.0
 isort==4.3.21
 mccabe==0.6.1
 more-itertools==7.2.0
-packaging==19.2
+packaging>=22.0
 pathspec>=0.9.0
 pluggy==0.13.0
 pycodestyle==2.5.0
-pydantic==1.10.2
+pydantic==1.10.18
 pyflakes==2.1.1
 pyparsing==2.4.2
 pytest==7.2.0
 pytest-cov==2.11.1
 PyYAML>=5.4
-regex==2020.2.20
+regex>=2020.2.20
 requests>=2.22.0
 httpx>=0.20.0
 six==1.12.0
 SQLAlchemy>=1.4.19
 asyncpg>=0.27.0
-aiosqlite==0.19.0
-sqlparse==0.4.4
+aiosqlite==0.20.0
+sqlparse==0.5.1
 starlette>=0.13.6
 toml>=0.10.1
 typed-ast>=1.4.2
 urllib3>=1.25.9
 wcwidth==0.1.7
-zipp==3.1.0
-black==22.3.0
-pytest-asyncio>=0.15.0
-greenlet==2.0.2
+zipp==3.19.1
+black==24.4.2
+pytest-asyncio==0.21.0
+greenlet==3.1.1
diff --git a/setup.py b/setup.py
index d9b0d89..09786b6 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
     python_requires=">=3.7",
     install_requires=["starlette>=0.13.6", "SQLAlchemy>=1.4.19"],
     classifiers=[
-        "Development Status :: 4 - Beta",
+        "Development Status :: 5 - Production/Stable",
         "Environment :: Web Environment",
         "Framework :: AsyncIO",
         "Intended Audience :: Developers",
@@ -40,6 +40,7 @@
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
         "Programming Language :: Python :: 3 :: Only",
         "Programming Language :: Python :: Implementation :: CPython",
         "Topic :: Internet :: WWW/HTTP :: HTTP Servers",
diff --git a/tests/test_session.py b/tests/test_session.py
index 06163ac..9400fea 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,4 +1,7 @@
+import asyncio
+
 import pytest
+from sqlalchemy import text
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from starlette.middleware.base import BaseHTTPMiddleware
@@ -148,3 +151,50 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_
     session_args = {"expire_on_commit": False}
     async with db(session_args=session_args):
         db.session
+
+
+@pytest.mark.asyncio
+async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
+    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
+
+    async with db(multi_sessions=True):
+
+        async def execute_query(query):
+            return await db.session.execute(text(query))
+
+        tasks = [
+            asyncio.create_task(execute_query("SELECT 1")),
+            asyncio.create_task(execute_query("SELECT 2")),
+            asyncio.create_task(execute_query("SELECT 3")),
+            asyncio.create_task(execute_query("SELECT 4")),
+            asyncio.create_task(execute_query("SELECT 5")),
+            asyncio.create_task(execute_query("SELECT 6")),
+        ]
+
+        res = await asyncio.gather(*tasks)
+        assert len(res) == 6
+
+
+@pytest.mark.asyncio
+async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
+    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
+
+    async with db(multi_sessions=True, commit_on_exit=True):
+        await db.session.execute(
+            text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
+        )
+
+        async def insert_data(value):
+            await db.session.execute(
+                text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value}
+            )
+            await db.session.flush()
+
+        tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)]
+
+        result_ids = await asyncio.gather(*tasks)
+        assert len(result_ids) == 10
+
+        records = await db.session.execute(text("SELECT * FROM my_model"))
+        records = records.scalars().all()
+        assert len(records) == 10
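One property the new tests exercise only indirectly: in the default mode every `db.session` access inside the context returns the same `AsyncSession` held in the `_session` context variable, while with `multi_sessions=True` every access constructs a new one (an `AsyncSession` is not meant to be shared by concurrently running tasks, which is the motivation for the feature). A possible follow-up test in the same style as the ones above, reusing the imports, fixtures (`app`, `db`, `SQLAlchemyMiddleware`) and `db_url` constant of tests/test_session.py; it is a sketch, not part of this change:

```python
@pytest.mark.asyncio
async def test_session_identity(app, db, SQLAlchemyMiddleware):
    app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

    # Default mode: one session per context, so repeated accesses return the same object.
    async with db():
        assert db.session is db.session

    # multi_sessions mode: every access builds a fresh AsyncSession, so concurrent
    # tasks each work on their own session and never interfere with each other.
    async with db(multi_sessions=True):
        assert db.session is not db.session
```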