diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7723dcc..32ab205 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,8 +15,8 @@ repos:
       - id: pyupgrade
         args:
           - --py37-plus
-  - repo: https://github.com/myint/autoflake
-    rev: v1.4
+  - repo: https://github.com/PyCQA/autoflake
+    rev: v2.2.1
     hooks:
       - id: autoflake
         args:
@@ -30,7 +30,7 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 23.12.1
     hooks:
       - id: black
   - repo: https://github.com/PyCQA/flake8
diff --git a/README.md b/README.md
index 32898a2..81cadff 100644
--- a/README.md
+++ b/README.md
@@ -167,3 +167,114 @@ async def parallel_select():
 
     await asyncio.gather(*tasks)
 ```
+
+#### Using SQLAlchemy Events
+
+SQLAlchemy's event system works with this package out of the box. You can register listeners for mapper-level events such as `before_insert`, `after_insert`, `before_update`, and `after_update`.
+
+```python
+from fastapi import FastAPI
+from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db
+from sqlalchemy import Column, Integer, String, event
+from sqlalchemy.orm import DeclarativeBase
+
+app = FastAPI()
+app.add_middleware(
+    SQLAlchemyMiddleware,
+    db_url="postgresql+asyncpg://user:user@192.168.88.200:5432/primary_db",
+)
+
+class Base(DeclarativeBase):
+    pass
+
+class User(Base):
+    __tablename__ = "users"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50))
+    email = Column(String(100))
+
+# Event listeners
+@event.listens_for(User, 'before_insert')
+def before_insert_listener(mapper, connection, target):
+    print(f"About to insert user: {target.name}")
+    # The target object can still be modified at this point
+    target.name = target.name.title()  # Capitalize the name
+
+@event.listens_for(User, 'after_insert')
+def after_insert_listener(mapper, connection, target):
+    print(f"User {target.name} inserted with ID: {target.id}")
+
+@event.listens_for(User, 'before_update')
+def before_update_listener(mapper, connection, target):
+    print(f"About to update user: {target.name}")
+
+@event.listens_for(User, 'after_update')
+def after_update_listener(mapper, connection, target):
+    print(f"User {target.name} updated")
+
+# Usage in routes
+@app.post("/users")
+async def create_user(name: str, email: str):
+    user = User(name=name, email=email)
+    db.session.add(user)
+    await db.session.commit()
+    return {"id": user.id, "name": user.name, "email": user.email}
+
+@app.put("/users/{user_id}")
+async def update_user(user_id: int, name: str):
+    user = await db.session.get(User, user_id)
+    if user:
+        user.name = name
+        await db.session.commit()
+        return {"id": user.id, "name": user.name}
+    return {"error": "User not found"}
+```
+
+Note that mapper-level listeners run synchronously during the flush and receive a sync `Connection`, so they must not use `await`.
+
+#### Advanced Event Usage
+
+Events also cover more complex scenarios such as auditing, validation, or triggering follow-up actions:
+
+```python
+from datetime import datetime
+from sqlalchemy import Column, DateTime, Integer, String, event
+
+class AuditMixin:
+    created_at = Column(DateTime, default=datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+class User(Base, AuditMixin):
+    __tablename__ = "users"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50))
+    email = Column(String(100))
+
+# Validation event
+@event.listens_for(User, 'before_insert')
+@event.listens_for(User, 'before_update')
+def validate_user(mapper, connection, target):
+    if not target.email or '@' not in target.email:
+        raise ValueError("Invalid email address")
+    if not target.name or len(target.name.strip()) < 2:
+        raise ValueError("Name must be at least 2 characters long")
+
+# Audit logging: register one listener per event so the action name is known
+def make_audit_listener(action):
+    def audit_user_changes(mapper, connection, target):
+        # Log changes to an audit table or external system
+        print(f"AUDIT: {action} on User {target.id} at {datetime.utcnow()}")
+    return audit_user_changes
+
+event.listen(User, 'after_insert', make_audit_listener("INSERT"))
+event.listen(User, 'after_update', make_audit_listener("UPDATE"))
+event.listen(User, 'after_delete', make_audit_listener("DELETE"))
+```
+
+All of the standard mapper-level events are supported:
+- `before_insert`, `after_insert`
+- `before_update`, `after_update`
+- `before_delete`, `after_delete`
+
+Session-level events such as `after_bulk_update` and `after_bulk_delete` work as well.
+
+Events work with both single database setups and multiple database configurations, as sketched below.
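+
+Because listeners are attached to the model's mapper rather than to a session or engine, the same hooks fire regardless of which database session flushes the model. Below is a minimal sketch for a second database, assuming `create_middleware_and_session_proxy` returns an independent middleware/session-proxy pair as in the single-database setup; the second host and `secondary_db` name are placeholders:
+
+```python
+from fastapi_async_sqlalchemy import create_middleware_and_session_proxy
+
+# Hypothetical second database, wired up through its own middleware and proxy
+SecondaryMiddleware, secondary_db = create_middleware_and_session_proxy()
+app.add_middleware(
+    SecondaryMiddleware,
+    db_url="postgresql+asyncpg://user:user@192.168.88.201:5432/secondary_db",
+)
+
+@app.post("/users-secondary")
+async def create_user_secondary(name: str, email: str):
+    user = User(name=name, email=email)
+    secondary_db.session.add(user)  # before_insert/after_insert still fire on flush
+    await secondary_db.session.commit()
+    return {"id": user.id, "name": user.name}
+```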
diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py
index 963466d..51a793a 100644
--- a/fastapi_async_sqlalchemy/__init__.py
+++ b/fastapi_async_sqlalchemy/__init__.py
@@ -1,7 +1,7 @@
 from fastapi_async_sqlalchemy.middleware import (
     SQLAlchemyMiddleware,
-    db,
     create_middleware_and_session_proxy,
+    db,
 )
 
 __all__ = ["db", "SQLAlchemyMiddleware", "create_middleware_and_session_proxy"]
diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py
index 1171ede..2d06db6 100644
--- a/fastapi_async_sqlalchemy/middleware.py
+++ b/fastapi_async_sqlalchemy/middleware.py
@@ -1,4 +1,5 @@
 import asyncio
+import weakref
 from contextvars import ContextVar
 from typing import Dict, Optional, Union
 
@@ -17,14 +18,123 @@
 from sqlalchemy.orm import sessionmaker as async_sessionmaker
 
 
-def create_middleware_and_session_proxy():
-    _Session: Optional[async_sessionmaker] = None
-    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
-    _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
-    _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
-    # Usage of context vars inside closures is not recommended, since they are not properly
-    # garbage collected, but in our use case context var is created on program startup and
-    # is used throughout the whole its lifecycle.
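+# Sessions opened in multi_sessions mode are tracked per asyncio task in a
+# WeakKeyDictionary, so entries disappear once a finished task is garbage
+# collected; a done-callback on each task closes that task's sessions as soon
+# as it completes, without explicit bookkeeping by the caller.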
+class TaskSessionManager:
+    """Manages sessions per asyncio task with automatic cleanup."""
+
+    def __init__(self, max_sessions_per_task: int = 5):
+        self.max_sessions_per_task = max_sessions_per_task
+        self._task_sessions_map: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+        self._cleanup_tasks: set = set()
+        self._lock = asyncio.Lock()
+
+    def get_or_create_session(self, session_factory, commit_on_exit: bool = False) -> AsyncSession:
+        """Create a new session for the current task and track it for cleanup."""
+        task = asyncio.current_task()
+        if task is None:
+            # No current task; return an untracked session (needs manual cleanup)
+            return session_factory()
+
+        # An async lock cannot be used here because this method is called from
+        # sync context; since the event loop is single-threaded, the operations
+        # below are never interleaved with other coroutines.
+        # Get or create the session list for this task
+        if task not in self._task_sessions_map:
+            self._task_sessions_map[task] = []
+
+        task_sessions = self._task_sessions_map[task]
+
+        # Enforce the per-task session limit
+        if len(task_sessions) >= self.max_sessions_per_task:
+            raise RuntimeError(
+                f"Maximum sessions per task ({self.max_sessions_per_task}) exceeded"
+            )
+
+        # Create and track a new session
+        session = session_factory()
+        task_sessions.append((session, commit_on_exit))
+
+        # Register the cleanup callback once, on the task's first session
+        if len(task_sessions) == 1:
+            task.add_done_callback(self._cleanup_task_sessions)
+
+        return session
+
+    def _cleanup_task_sessions(self, task):
+        """Clean up all sessions once their task completes."""
+        # Pop the entry so cleanup_all_sessions() cannot clean it up a second time
+        sessions = self._task_sessions_map.pop(task, None)
+        if sessions:
+            cleanup_task = asyncio.create_task(self._async_cleanup_sessions(sessions))
+            # Keep a strong reference so the cleanup task is not garbage
+            # collected before it finishes
+            self._cleanup_tasks.add(cleanup_task)
+            cleanup_task.add_done_callback(self._handle_cleanup_result)
+
+    def _handle_cleanup_result(self, task):
+        """Drop the reference to a finished cleanup task and swallow its errors."""
+        self._cleanup_tasks.discard(task)
+        try:
+            task.result()  # Raises if the cleanup task failed
+        except Exception:
+            pass  # Suppress cleanup errors to avoid unhandled exceptions
+
+    async def _async_cleanup_sessions(self, sessions):
+        """Commit (when requested), roll back on failure, and close each session."""
+        for session, commit_on_exit in sessions:
+            try:
+                if commit_on_exit:
+                    await session.commit()
+            except Exception:
+                try:
+                    await session.rollback()
+                except Exception:
+                    pass  # Suppress rollback errors
+            finally:
+                try:
+                    await session.close()
+                except Exception:
+                    pass  # Suppress close errors
+
+    async def cleanup_all_sessions(self, exc_type=None, commit_on_exit=False):
+        """Clean up all tracked sessions (used on context-manager exit)."""
+        async with self._lock:
+            all_sessions = []
+            for task_sessions in self._task_sessions_map.values():
+                all_sessions.extend(task_sessions)
+
+            # Clear the map first to prevent double cleanup
+            self._task_sessions_map.clear()
+
+            # Close every session before re-raising, so a failure on one
+            # session does not leak the others
+            first_exception = None
+            for session, session_commit_on_exit in all_sessions:
+                try:
+                    if exc_type is not None:
+                        await session.rollback()
+                    elif commit_on_exit or session_commit_on_exit:
+                        await session.commit()
+                except Exception as e:
+                    if first_exception is None:
+                        first_exception = e
+                    try:
+                        await session.rollback()
+                    except Exception:
+                        pass  # Suppress rollback errors
+                finally:
+                    try:
+                        await session.close()
+                    except Exception as close_error:
+                        if first_exception is None:
+                            first_exception = close_error
+
+            # Re-raise the first exception encountered, if any
+            if first_exception is not None:
+                raise first_exception
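+
+# Usage sketch (illustrative): in multi_sessions mode the session proxy calls
+#     _task_session_manager.get_or_create_session(_Session, commit_on_exit)
+# on every `db.session` access, and the context manager's __aexit__ awaits
+# cleanup_all_sessions() as a final safety net for tasks still pending at exit.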
+
+
+# Module-level (rather than closure-scoped) context variables to avoid memory leaks
+_Session: Optional[async_sessionmaker] = None
+_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
+_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
+_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
+_task_session_manager: Optional[TaskSessionManager] = None
+
+
+def create_middleware_and_session_proxy(max_sessions_per_task: int = 5):
+    global _Session, _task_session_manager
+    # Initialize the task session manager for multi_sessions mode
+    _task_session_manager = TaskSessionManager(max_sessions_per_task=max_sessions_per_task)
 
     class SQLAlchemyMiddleware(BaseHTTPMiddleware):
         def __init__(
@@ -48,7 +158,7 @@ def __init__(
             else:
                 engine = custom_engine
 
-            nonlocal _Session
+            global _Session
             _Session = async_sessionmaker(
                 engine, class_=AsyncSession, expire_on_commit=False, **session_args
             )
@@ -66,10 +176,9 @@ def session(self) -> AsyncSession:
             multi_sessions = _multi_sessions_ctx.get()
             if multi_sessions:
-                """In this case, we need to create a new session for each task.
-                We also need to commit the session on exit if commit_on_exit is True.
-                This is useful when we need to run multiple queries in parallel.
-                For example, when we need to run multiple queries in parallel in a route handler.
+                """In multi_sessions mode, a new session is created for each task.
+                Sessions are cleaned up automatically when their tasks complete.
+
                 Example:
                 ```python
                 async with db(multi_sessions=True):
@@ -79,33 +188,14 @@ async def execute_query(query):
                     tasks = [
                         asyncio.create_task(execute_query("SELECT 1")),
                         asyncio.create_task(execute_query("SELECT 2")),
-                        asyncio.create_task(execute_query("SELECT 3")),
-                        asyncio.create_task(execute_query("SELECT 4")),
-                        asyncio.create_task(execute_query("SELECT 5")),
-                        asyncio.create_task(execute_query("SELECT 6")),
                     ]
-
                     await asyncio.gather(*tasks)
                 ```
                 """
                 commit_on_exit = _commit_on_exit_ctx.get()
-                # Always create a new session for each access when multi_sessions=True
-                session = _Session()
-
-                async def cleanup():
-                    try:
-                        if commit_on_exit:
-                            await session.commit()
-                    except Exception:
-                        await session.rollback()
-                        raise
-                    finally:
-                        await session.close()
-
-                task = asyncio.current_task()
-                if task is not None:
-                    task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
-                return session
+                if _Session is None or _task_session_manager is None:
+                    raise SessionNotInitialisedError
+                return _task_session_manager.get_or_create_session(_Session, commit_on_exit)
             else:
                 session = _session.get()
                 if session is None:
@@ -138,8 +228,16 @@ async def __aenter__(self):
         async def __aexit__(self, exc_type, exc_value, traceback):
             if self.multi_sessions:
-                _multi_sessions_ctx.reset(self.multi_sessions_token)
-                _commit_on_exit_ctx.reset(self.commit_on_exit_token)
+                try:
+                    # Clean up all sessions tracked by the task session manager
+                    if _task_session_manager is not None:
+                        await _task_session_manager.cleanup_all_sessions(
+                            exc_type, self.commit_on_exit
+                        )
+                finally:
+                    # Always reset the context variables, even if cleanup fails
+                    _multi_sessions_ctx.reset(self.multi_sessions_token)
+                    _commit_on_exit_ctx.reset(self.commit_on_exit_token)
             else:
                 session = _session.get()
                 try: