diff --git a/README.md b/README.md
index 32898a2..e438988 100644
--- a/README.md
+++ b/README.md
@@ -167,3 +167,114 @@ async def parallel_select():
         await asyncio.gather(*tasks)
 ```
+
+#### Using SQLAlchemy Events
+
+SQLAlchemy events work seamlessly with this package. You can listen to events such as `before_insert`, `after_insert`, `before_update`, and `after_update`.
+
+```python
+from fastapi import FastAPI
+from fastapi_async_sqlalchemy import SQLAlchemyMiddleware, db
+from sqlalchemy import Column, Integer, String, event
+from sqlalchemy.orm import DeclarativeBase
+
+app = FastAPI()
+app.add_middleware(
+    SQLAlchemyMiddleware,
+    db_url="postgresql+asyncpg://user:user@192.168.88.200:5432/primary_db",
+)
+
+class Base(DeclarativeBase):
+    pass
+
+class User(Base):
+    __tablename__ = "users"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50))
+    email = Column(String(100))
+
+# Event listeners
+@event.listens_for(User, 'before_insert')
+def before_insert_listener(mapper, connection, target):
+    print(f"About to insert user: {target.name}")
+    # You can modify the target object here
+    target.name = target.name.title()  # Capitalize name
+
+@event.listens_for(User, 'after_insert')
+def after_insert_listener(mapper, connection, target):
+    print(f"User {target.name} inserted with ID: {target.id}")
+
+@event.listens_for(User, 'before_update')
+def before_update_listener(mapper, connection, target):
+    print(f"About to update user: {target.name}")
+
+@event.listens_for(User, 'after_update')
+def after_update_listener(mapper, connection, target):
+    print(f"User {target.name} updated")
+
+# Usage in routes
+@app.post("/users")
+async def create_user(name: str, email: str):
+    user = User(name=name, email=email)
+    db.session.add(user)
+    await db.session.commit()
+    return {"id": user.id, "name": user.name, "email": user.email}
+
+@app.put("/users/{user_id}")
+async def update_user(user_id: int, name: str):
+    user = await db.session.get(User, user_id)
+    if user:
+        user.name = name
+        await db.session.commit()
+        return {"id": user.id, "name": user.name}
+    return {"error": "User not found"}
+```
+
+#### Advanced Event Usage
+
+You can also use events for more complex scenarios such as auditing, validation, or triggering other actions:
+
+```python
+from datetime import datetime
+from sqlalchemy import Column, DateTime, event
+
+class AuditMixin:
+    created_at = Column(DateTime, default=datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+class User(Base, AuditMixin):
+    __tablename__ = "users"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50))
+    email = Column(String(100))
+
+# Validation event
+@event.listens_for(User, 'before_insert')
+@event.listens_for(User, 'before_update')
+def validate_user(mapper, connection, target):
+    if not target.email or '@' not in target.email:
+        raise ValueError("Invalid email address")
+    if not target.name or len(target.name.strip()) < 2:
+        raise ValueError("Name must be at least 2 characters long")
+
+# Audit logging events; a small factory records which operation fired,
+# since a single shared listener cannot tell the three events apart
+def make_audit_listener(action):
+    def audit_user_changes(mapper, connection, target):
+        # Log changes to audit table or external system
+        print(f"AUDIT: {action} on User {target.id} at {datetime.utcnow()}")
+    return audit_user_changes
+
+event.listen(User, 'after_insert', make_audit_listener("INSERT"))
+event.listen(User, 'after_update', make_audit_listener("UPDATE"))
+event.listen(User, 'after_delete', make_audit_listener("DELETE"))
+```
+
+All of the standard mapper-level ORM events are supported:
+- `before_insert`, `after_insert`
+- `before_update`, `after_update`
+- `before_delete`, `after_delete`
+
+Because these listeners are attached to the model class rather than to a session, events work with both single database setups and multiple database configurations.
diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py
index 1171ede..ba466fa 100644
--- a/fastapi_async_sqlalchemy/middleware.py
+++ b/fastapi_async_sqlalchemy/middleware.py
@@ -22,6 +22,9 @@ def create_middleware_and_session_proxy():
     _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
     _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
     _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
+    # Store sessions per task for multi_sessions mode to enable reuse and proper cleanup
+    _task_sessions: ContextVar[Dict[int, AsyncSession]] = ContextVar("_task_sessions", default={})
+    _cleanup_tasks: ContextVar[set] = ContextVar("_cleanup_tasks", default=set())
     # Usage of context vars inside closures is not recommended, since they are not properly
     # garbage collected, but in our use case context var is created on program startup and
     # is used throughout the whole its lifecycle.
@@ -66,7 +69,7 @@ def session(self) -> AsyncSession:
 
             multi_sessions = _multi_sessions_ctx.get()
             if multi_sessions:
-                """In this case, we need to create a new session for each task.
+                """In this case, we reuse sessions per task for better performance.
                 We also need to commit the session on exit if commit_on_exit is True.
                 This is useful when we need to run multiple queries in parallel.
                 For example, when we need to run multiple queries in parallel in a route handler.
@@ -88,11 +91,28 @@ async def execute_query(query):
                     await asyncio.gather(*tasks)
                 ```
                 """
+                task = asyncio.current_task()
+                if task is None:
+                    # Fallback: create a new session if no current task
+                    return _Session()
+
+                task_id = id(task)
+                task_sessions = _task_sessions.get()
+
+                # Reuse existing session for this task if available
+                if task_id in task_sessions:
+                    return task_sessions[task_id]
+
+                # Create new session for this task
                 commit_on_exit = _commit_on_exit_ctx.get()
-                # Always create a new session for each access when multi_sessions=True
                 session = _Session()
+
+                # Store session for reuse within this task
+                task_sessions[task_id] = session
+                _task_sessions.set(task_sessions)
 
-                async def cleanup():
+                async def cleanup_session():
+                    """Properly cleanup session when task completes."""
                     try:
                         if commit_on_exit:
                             await session.commit()
@@ -101,10 +121,33 @@ async def cleanup():
                         raise
                     finally:
                         await session.close()
-
-                task = asyncio.current_task()
-                if task is not None:
-                    task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
+                        # Remove from task sessions
+                        current_sessions = _task_sessions.get()
+                        current_sessions.pop(task_id, None)
+                        _task_sessions.set(current_sessions)
+
+                def task_done_callback(completed_task):
+                    """Callback to schedule cleanup when task is done."""
+                    # Schedule cleanup to run in the next event loop iteration
+                    # This ensures it runs after the current task completes
+                    def schedule_cleanup():
+                        cleanup_tasks = _cleanup_tasks.get()
+                        cleanup_task = asyncio.create_task(cleanup_session())
+                        cleanup_tasks.add(cleanup_task)
+                        _cleanup_tasks.set(cleanup_tasks)
+
+                        # Remove completed cleanup task from tracking
+                        def cleanup_done(ct):
+                            current_cleanup_tasks = _cleanup_tasks.get()
+                            current_cleanup_tasks.discard(ct)
+                            _cleanup_tasks.set(current_cleanup_tasks)
+
+                        cleanup_task.add_done_callback(cleanup_done)
+
+                    # Use call_soon to ensure cleanup is scheduled properly
+                    asyncio.get_event_loop().call_soon(schedule_cleanup)
+
+                task.add_done_callback(task_done_callback)
                 return session
             else:
                 session = _session.get()
@@ -138,6 +181,37 @@ async def __aenter__(self):
 
         async def __aexit__(self, exc_type, exc_value, traceback):
             if self.multi_sessions:
+                # Clean up any remaining sessions and wait for cleanup tasks
+                task_sessions = _task_sessions.get()
+                cleanup_tasks = _cleanup_tasks.get()
+
+                # Clean up any remaining sessions that weren't cleaned up by task callbacks
+                for task_id, session in list(task_sessions.items()):
+                    try:
+                        if exc_type is not None:
+                            await session.rollback()
+                        elif self.commit_on_exit:
+                            await session.commit()
+                    except Exception:
+                        await session.rollback()
+                        raise
+                    finally:
+                        await session.close()
+                        task_sessions.pop(task_id, None)
+
+                # Wait for all cleanup tasks to complete
+                if cleanup_tasks:
+                    try:
+                        await asyncio.gather(*cleanup_tasks, return_exceptions=True)
+                    except Exception:
+                        # Cleanup tasks should handle their own exceptions
+                        pass
+                    finally:
+                        cleanup_tasks.clear()
+
+                # Reset context variables
+                _task_sessions.set({})
+                _cleanup_tasks.set(set())
                 _multi_sessions_ctx.reset(self.multi_sessions_token)
                 _commit_on_exit_ctx.reset(self.commit_on_exit_token)
             else:
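To make the behavioural change concrete, here is a minimal usage sketch. It is not part of the diff: it assumes the `db(multi_sessions=True, commit_on_exit=True)` context manager documented in the README above, and the `parallel_work` and `worker` names plus the `SELECT` statements are placeholders. With the patch, repeated `db.session` accesses inside one asyncio task return the same `AsyncSession`, while each task spawned by `asyncio.gather` still gets its own session.

```python
import asyncio

from fastapi_async_sqlalchemy import db
from sqlalchemy import text


async def parallel_work():
    # Assumes SQLAlchemyMiddleware has already been added to the app,
    # so the session factory is initialised.
    async with db(multi_sessions=True, commit_on_exit=True):

        async def worker():
            # After this patch both statements run on the same AsyncSession,
            # because the session is cached per asyncio task.
            await db.session.execute(text("SELECT 1"))
            await db.session.execute(text("SELECT 2"))

        # Each task created by gather still gets its own session, which is
        # committed and closed by the task's done callback.
        await asyncio.gather(*(worker() for _ in range(5)))
```

Because `asyncio.gather` wraps each coroutine in its own task, concurrent workers stay isolated from each other; the reuse introduced here is keyed on `id(asyncio.current_task())`, so it only applies within a single task.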
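The README addition also states that events work with multiple database configurations. The sketch below, again not part of the diff, illustrates why under two assumptions: it reuses the `User` model and listeners from the events example above, and it follows the two-proxy pattern from the package's multiple-database documentation, where `create_middleware_and_session_proxy()` returns a middleware class together with a session proxy. The middleware names and database URLs are placeholders. Mapper-level listeners are attached to the model class, so they fire regardless of which proxy's session performs the flush.

```python
from fastapi import FastAPI
from fastapi_async_sqlalchemy import create_middleware_and_session_proxy

# Two independent middleware/proxy pairs, following the package's
# multiple-database pattern; the URLs are placeholders.
PrimaryMiddleware, primary_db = create_middleware_and_session_proxy()
AuditMiddleware, audit_db = create_middleware_and_session_proxy()

app = FastAPI()
app.add_middleware(
    PrimaryMiddleware,
    db_url="postgresql+asyncpg://user:user@localhost:5432/primary_db",
)
app.add_middleware(
    AuditMiddleware,
    db_url="postgresql+asyncpg://user:user@localhost:5432/audit_db",
)


@app.post("/users")
async def create_user(name: str, email: str):
    # The before_insert/after_insert listeners registered on User in the
    # events example fire here as well: mapper events belong to the model
    # class, not to whichever session proxy performs the flush.
    user = User(name=name, email=email)  # User model from the README example
    primary_db.session.add(user)
    await primary_db.session.commit()
    return {"id": user.id, "name": user.name}
```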