diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 91d300c..d819c8f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,7 +34,7 @@ jobs: strategy: fail-fast: false matrix: - python_version: [3.7, 3.8, 3.9, '3.10', '3.11'] + python_version: [3.9, '3.10', '3.11', '3.12', '3.13'] database_url: [ "sqlite+aiosqlite:///./test-fastapiusers.db", @@ -43,9 +43,9 @@ jobs: ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python_version }} - name: Install dependencies @@ -61,7 +61,7 @@ jobs: DATABASE_URL: ${{ matrix.database_url }} run: | hatch run test-cov-xml - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true @@ -78,11 +78,11 @@ jobs: if: startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: 3.7 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index b949f48..434348e 100644 --- a/.gitignore +++ b/.gitignore @@ -104,9 +104,6 @@ ENV/ # mypy .mypy_cache/ -# .vscode -.vscode/ - # OS files .DS_Store diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..5d8d955 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,21 @@ +{ + "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoImportCompletions": true, + "python.terminal.activateEnvironment": true, + "python.terminal.activateEnvInCurrentTerminal": true, + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "editor.rulers": [88], + "python.defaultInterpreterPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/python", + "python.testing.pytestPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/pytest", + "python.testing.cwd": "${workspaceFolder}", + "python.testing.pytestArgs": ["--no-cov"], + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } + } diff --git a/README.md b/README.md index d9756d4..2f8498c 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![PyPI version](https://badge.fury.io/py/fastapi-users-db-sqlalchemy.svg)](https://badge.fury.io/py/fastapi-users-db-sqlalchemy) [![Downloads](https://pepy.tech/badge/fastapi-users-db-sqlalchemy)](https://pepy.tech/project/fastapi-users-db-sqlalchemy)

- +

--- diff --git a/fastapi_users_db_sqlalchemy/__init__.py b/fastapi_users_db_sqlalchemy/__init__.py index 87d0733..467a2bf 100644 --- a/fastapi_users_db_sqlalchemy/__init__.py +++ b/fastapi_users_db_sqlalchemy/__init__.py @@ -1,22 +1,22 @@ """FastAPI Users database adapter for SQLAlchemy.""" + import uuid -from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type +from typing import TYPE_CHECKING, Any, Generic, Optional from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import ID, OAP, UP -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, func, select +from sqlalchemy import Boolean, ForeignKey, Integer, String, func, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import declarative_mixin, declared_attr +from sqlalchemy.orm import Mapped, declared_attr, mapped_column from sqlalchemy.sql import Select from fastapi_users_db_sqlalchemy.generics import GUID -__version__ = "4.0.5" +__version__ = "7.0.0" UUID_ID = uuid.UUID -@declarative_mixin class SQLAlchemyBaseUserTable(Generic[ID]): """Base SQLAlchemy users table definition.""" @@ -30,22 +30,28 @@ class SQLAlchemyBaseUserTable(Generic[ID]): is_superuser: bool is_verified: bool else: - email: str = Column(String(length=320), unique=True, index=True, nullable=False) - hashed_password: str = Column(String(length=1024), nullable=False) - is_active: bool = Column(Boolean, default=True, nullable=False) - is_superuser: bool = Column(Boolean, default=False, nullable=False) - is_verified: bool = Column(Boolean, default=False, nullable=False) + email: Mapped[str] = mapped_column( + String(length=320), unique=True, index=True, nullable=False + ) + hashed_password: Mapped[str] = mapped_column( + String(length=1024), nullable=False + ) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + is_superuser: Mapped[bool] = mapped_column( + Boolean, default=False, nullable=False + ) + is_verified: Mapped[bool] = mapped_column( + Boolean, default=False, nullable=False + ) -@declarative_mixin class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]): if TYPE_CHECKING: # pragma: no cover id: UUID_ID else: - id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) -@declarative_mixin class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): """Base SQLAlchemy OAuth account table definition.""" @@ -60,24 +66,32 @@ class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): account_id: str account_email: str else: - oauth_name: str = Column(String(length=100), index=True, nullable=False) - access_token: str = Column(String(length=1024), nullable=False) - expires_at: Optional[int] = Column(Integer, nullable=True) - refresh_token: Optional[str] = Column(String(length=1024), nullable=True) - account_id: str = Column(String(length=320), index=True, nullable=False) - account_email: str = Column(String(length=320), nullable=False) + oauth_name: Mapped[str] = mapped_column( + String(length=100), index=True, nullable=False + ) + access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False) + expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + refresh_token: Mapped[Optional[str]] = mapped_column( + String(length=1024), nullable=True + ) + account_id: Mapped[str] = mapped_column( + String(length=320), index=True, nullable=False + ) + account_email: Mapped[str] = mapped_column(String(length=320), nullable=False) -@declarative_mixin class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]): if TYPE_CHECKING: # pragma: no cover id: UUID_ID + user_id: UUID_ID else: - id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) - @declared_attr - def user_id(cls) -> Column[GUID]: - return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False) + @declared_attr + def user_id(cls) -> Mapped[GUID]: + return mapped_column( + GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False + ) class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): @@ -90,14 +104,14 @@ class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ session: AsyncSession - user_table: Type[UP] - oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] + user_table: type[UP] + oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] def __init__( self, session: AsyncSession, - user_table: Type[UP], - oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] = None, + user_table: type[UP], + oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] = None, ): self.session = session self.user_table = user_table @@ -120,19 +134,19 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP statement = ( select(self.user_table) .join(self.oauth_account_table) - .where(self.oauth_account_table.oauth_name == oauth) - .where(self.oauth_account_table.account_id == account_id) + .where(self.oauth_account_table.oauth_name == oauth) # type: ignore + .where(self.oauth_account_table.account_id == account_id) # type: ignore ) return await self._get_user(statement) - async def create(self, create_dict: Dict[str, Any]) -> UP: + async def create(self, create_dict: dict[str, Any]) -> UP: user = self.user_table(**create_dict) self.session.add(user) await self.session.commit() await self.session.refresh(user) return user - async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + async def update(self, user: UP, update_dict: dict[str, Any]) -> UP: for key, value in update_dict.items(): setattr(user, key, value) self.session.add(user) @@ -144,10 +158,11 @@ async def delete(self, user: UP) -> None: await self.session.delete(user) await self.session.commit() - async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP: if self.oauth_account_table is None: raise NotImplementedError() + await self.session.refresh(user) oauth_account = self.oauth_account_table(**create_dict) self.session.add(oauth_account) user.oauth_accounts.append(oauth_account) # type: ignore @@ -158,7 +173,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: return user async def update_oauth_account( - self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + self, user: UP, oauth_account: OAP, update_dict: dict[str, Any] ) -> UP: if self.oauth_account_table is None: raise NotImplementedError() @@ -172,8 +187,4 @@ async def update_oauth_account( async def _get_user(self, statement: Select) -> Optional[UP]: results = await self.session.execute(statement) - user = results.first() - if user is None: - return None - - return user[0] + return results.unique().scalar_one_or_none() diff --git a/fastapi_users_db_sqlalchemy/access_token.py b/fastapi_users_db_sqlalchemy/access_token.py index 5913a8e..9f68af6 100644 --- a/fastapi_users_db_sqlalchemy/access_token.py +++ b/fastapi_users_db_sqlalchemy/access_token.py @@ -1,17 +1,16 @@ import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type +from typing import TYPE_CHECKING, Any, Generic, Optional from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase from fastapi_users.models import ID -from sqlalchemy import Column, ForeignKey, String, select +from sqlalchemy import ForeignKey, String, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import declarative_mixin, declared_attr +from sqlalchemy.orm import Mapped, declared_attr, mapped_column from fastapi_users_db_sqlalchemy.generics import GUID, TIMESTAMPAware, now_utc -@declarative_mixin class SQLAlchemyBaseAccessTokenTable(Generic[ID]): """Base SQLAlchemy access token table definition.""" @@ -22,21 +21,20 @@ class SQLAlchemyBaseAccessTokenTable(Generic[ID]): created_at: datetime user_id: ID else: - token: str = Column(String(length=43), primary_key=True) - created_at: datetime = Column( + token: Mapped[str] = mapped_column(String(length=43), primary_key=True) + created_at: Mapped[datetime] = mapped_column( TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc ) -@declarative_mixin class SQLAlchemyBaseAccessTokenTableUUID(SQLAlchemyBaseAccessTokenTable[uuid.UUID]): if TYPE_CHECKING: # pragma: no cover user_id: uuid.UUID else: @declared_attr - def user_id(cls) -> Column[GUID]: - return Column( + def user_id(cls) -> Mapped[GUID]: + return mapped_column( GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False ) @@ -52,7 +50,7 @@ class SQLAlchemyAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): def __init__( self, session: AsyncSession, - access_token_table: Type[AP], + access_token_table: type[AP], ): self.session = session self.access_token_table = access_token_table @@ -61,25 +59,24 @@ async def get_by_token( self, token: str, max_age: Optional[datetime] = None ) -> Optional[AP]: statement = select(self.access_token_table).where( - self.access_token_table.token == token + self.access_token_table.token == token # type: ignore ) if max_age is not None: - statement = statement.where(self.access_token_table.created_at >= max_age) + statement = statement.where( + self.access_token_table.created_at >= max_age # type: ignore + ) results = await self.session.execute(statement) - access_token = results.first() - if access_token is None: - return None - return access_token[0] + return results.scalar_one_or_none() - async def create(self, create_dict: Dict[str, Any]) -> AP: + async def create(self, create_dict: dict[str, Any]) -> AP: access_token = self.access_token_table(**create_dict) self.session.add(access_token) await self.session.commit() await self.session.refresh(access_token) return access_token - async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + async def update(self, access_token: AP, update_dict: dict[str, Any]) -> AP: for key, value in update_dict.items(): setattr(access_token, key, value) self.session.add(access_token) diff --git a/fastapi_users_db_sqlalchemy/generics.py b/fastapi_users_db_sqlalchemy/generics.py index 5e4b46b..ddfe639 100644 --- a/fastapi_users_db_sqlalchemy/generics.py +++ b/fastapi_users_db_sqlalchemy/generics.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime, timezone +from typing import Optional from pydantic import UUID4 from sqlalchemy import CHAR, TIMESTAMP, TypeDecorator @@ -61,7 +62,7 @@ class TIMESTAMPAware(TypeDecorator): # pragma: no cover impl = TIMESTAMP cache_ok = True - def process_result_value(self, value: datetime, dialect): - if dialect.name != "postgresql": + def process_result_value(self, value: Optional[datetime], dialect): + if value is not None and dialect.name != "postgresql": return value.replace(tzinfo=timezone.utc) return value diff --git a/pyproject.toml b/pyproject.toml index 7583691..ef32d8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,11 +2,14 @@ plugins = "sqlalchemy.ext.mypy.plugin" [tool.pytest.ini_options] -asyncio_mode = "auto" +asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" addopts = "--ignore=test_build.py" [tool.ruff] -extend-select = ["I"] + +[tool.ruff.lint] +extend-select = ["I", "UP"] [tool.hatch] @@ -19,6 +22,7 @@ commit_extra_args = ["-e"] path = "fastapi_users_db_sqlalchemy/__init__.py" [tool.hatch.envs.default] +installer = "uv" dependencies = [ "aiosqlite", "asyncpg", @@ -40,13 +44,13 @@ dependencies = [ test = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=term-missing --cov-fail-under=100" test-cov-xml = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=xml --cov-fail-under=100" lint = [ - "black . ", - "ruff --fix .", + "ruff format . ", + "ruff check --fix .", "mypy fastapi_users_db_sqlalchemy/", ] lint-check = [ - "black --check .", - "ruff .", + "ruff format --check .", + "ruff check .", "mypy fastapi_users_db_sqlalchemy/", ] @@ -71,18 +75,18 @@ classifiers = [ "Framework :: FastAPI", "Framework :: AsyncIO", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3 :: Only", "Topic :: Internet :: WWW/HTTP :: Session", ] -requires-python = ">=3.7" +requires-python = ">=3.9" dependencies = [ "fastapi-users >= 10.0.0", - "sqlalchemy[asyncio] >=1.4,<2.0.0", + "sqlalchemy[asyncio] >=2.0.0,<2.1.0", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index 17d3c79..b7661c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ -import asyncio import os -from typing import Any, Dict, Optional +from typing import Any, Optional import pytest from fastapi_users import schemas @@ -26,16 +25,8 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass -@pytest.fixture(scope="session") -def event_loop(): - """Force the pytest-asyncio loop to be the main one.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - @pytest.fixture -def oauth_account1() -> Dict[str, Any]: +def oauth_account1() -> dict[str, Any]: return { "oauth_name": "service1", "access_token": "TOKEN", @@ -46,7 +37,7 @@ def oauth_account1() -> Dict[str, Any]: @pytest.fixture -def oauth_account2() -> Dict[str, Any]: +def oauth_account2() -> dict[str, Any]: return { "oauth_name": "service2", "access_token": "TOKEN", diff --git a/tests/test_access_token.py b/tests/test_access_token.py index 02ad09a..f0e5fb9 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -1,12 +1,18 @@ import uuid +from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone -from typing import AsyncGenerator import pytest +import pytest_asyncio from pydantic import UUID4 from sqlalchemy import exc -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID from fastapi_users_db_sqlalchemy.access_token import ( @@ -15,7 +21,9 @@ ) from tests.conftest import DATABASE_URL -Base = declarative_base() + +class Base(DeclarativeBase): + pass class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): @@ -27,7 +35,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): def create_async_session_maker(engine: AsyncEngine): - return sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @pytest.fixture @@ -35,7 +43,7 @@ def user_id() -> UUID4: return uuid.uuid4() -@pytest.fixture +@pytest_asyncio.fixture async def sqlalchemy_access_token_db( user_id: UUID4, ) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]: diff --git a/tests/test_users.py b/tests/test_users.py index 3bcdc34..4a83e39 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,9 +1,20 @@ -from typing import Any, AsyncGenerator, Dict, List +from collections.abc import AsyncGenerator +from typing import Any import pytest -from sqlalchemy import Column, String, exc -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.orm import declarative_base, relationship, sessionmaker +import pytest_asyncio +from sqlalchemy import String, exc +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + relationship, +) from fastapi_users_db_sqlalchemy import ( UUID_ID, @@ -15,17 +26,19 @@ def create_async_session_maker(engine: AsyncEngine): - return sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + return async_sessionmaker(engine, expire_on_commit=False) -Base = declarative_base() +class Base(DeclarativeBase): + pass class User(SQLAlchemyBaseUserTableUUID, Base): - first_name = Column(String(255), nullable=True) + first_name: Mapped[str] = mapped_column(String(255), nullable=True) -OAuthBase = declarative_base() +class OAuthBase(DeclarativeBase): + pass class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, OAuthBase): @@ -33,11 +46,13 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, OAuthBase): class UserOAuth(SQLAlchemyBaseUserTableUUID, OAuthBase): - first_name = Column(String(255), nullable=True) - oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") + first_name: Mapped[str] = mapped_column(String(255), nullable=True) + oauth_accounts: Mapped[list[OAuthAccount]] = relationship( + "OAuthAccount", lazy="joined" + ) -@pytest.fixture +@pytest_asyncio.fixture async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: engine = create_async_engine(DATABASE_URL) sessionmaker = create_async_session_maker(engine) @@ -52,7 +67,7 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: await connection.run_sync(Base.metadata.drop_all) -@pytest.fixture +@pytest_asyncio.fixture async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: engine = create_async_engine(DATABASE_URL) sessionmaker = create_async_session_maker(engine) @@ -155,8 +170,8 @@ async def test_queries_custom_fields( @pytest.mark.asyncio async def test_queries_oauth( sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserOAuth, UUID_ID], - oauth_account1: Dict[str, Any], - oauth_account2: Dict[str, Any], + oauth_account1: dict[str, Any], + oauth_account2: dict[str, Any], ): user_create = { "email": "lancelot@camelot.bt",