diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f6161fc4..077354e8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -25,7 +25,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 sccache: 'true' manylinux: auto before-script-linux: | @@ -70,7 +70,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v3 @@ -110,7 +110,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v3 @@ -164,7 +164,7 @@ jobs: uses: messense/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 manylinux: musllinux_1_2 - name: Upload wheels uses: actions/upload-artifact@v3 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7038d412..7d81fbd3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,12 +30,12 @@ jobs: - uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic -D warnings + args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic pytest: name: ${{matrix.job.os}}-${{matrix.py_version}} strategy: matrix: - py_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7bd77650..bef1e2c0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,8 +58,6 @@ repos: - clippy::all - -W - clippy::pedantic - - -D - - warnings - id: check types: diff --git a/Cargo.lock b/Cargo.lock index ce6f31cf..abf26770 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -997,13 +997,15 @@ dependencies = [ [[package]] name = "psqlpy" -version = "0.8.7" +version = "0.9.0" dependencies = [ "byteorder", "bytes", "chrono", "chrono-tz", "deadpool-postgres", + "futures", + "futures-channel", "futures-util", "geo-types", "itertools", @@ -1050,9 +1052,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +checksum = "57fe09249128b3173d092de9523eaa75136bf7ba85e0d69eca241c7939c933cc" dependencies = [ "cfg-if", "chrono", @@ -1070,8 +1072,8 @@ dependencies = [ [[package]] name = "pyo3-async-runtimes" -version = "0.21.0" -source = "git+https://github.com/chandr-andr/pyo3-async-runtimes.git?branch=main#284bd36d0426a988026f878cae22abdb179795e6" +version = "0.23.0" +source = "git+https://github.com/chandr-andr/pyo3-async-runtimes.git?branch=psqlpy#c2b8441b4910b0b5100536b23c7a2fd43f9eacd0" dependencies = [ "futures", "once_cell", @@ -1082,9 +1084,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +checksum = "1cd3927b5a78757a0d71aa9dff669f903b1eb64b54142a9bd9f757f8fde65fd7" dependencies = [ "once_cell", "target-lexicon", @@ -1092,9 +1094,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +checksum = "dab6bb2102bd8f991e7749f130a70d05dd557613e39ed2deeee8e9ca0c4d548d" dependencies = [ "libc", "pyo3-build-config", @@ -1102,9 +1104,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +checksum = "91871864b353fd5ffcb3f91f2f703a22a9797c91b9ab497b1acac7b07ae509c7" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1114,9 +1116,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +checksum = "43abc3b80bc20f3facd86cd3c60beed58c3e2aa26213f3cda368de39c60a27e4" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 0cdc86fe..5e7743de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "psqlpy" -version = "0.8.7" +version = "0.9.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,15 +10,8 @@ crate-type = ["cdylib"] [dependencies] deadpool-postgres = { git = "https://github.com/chandr-andr/deadpool.git", branch = "psqlpy" } -pyo3 = { version = "*", features = [ - "chrono", - "experimental-async", - "rust_decimal", - "py-clone", - "gil-refs", - "macros", -] } -pyo3-async-runtimes = { git = "https://github.com/chandr-andr/pyo3-async-runtimes.git", branch = "main", features = [ +pyo3 = { version = "0.23.4", features = ["chrono", "experimental-async", "rust_decimal", "py-clone", "macros"] } +pyo3-async-runtimes = { git = "https://github.com/chandr-andr/pyo3-async-runtimes.git", branch = "psqlpy", features = [ "tokio-runtime", ] } @@ -59,3 +52,5 @@ pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [ "postgres", ] } +futures-channel = "0.3.31" +futures = "0.3.31" diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts index 81da5e8f..133833f7 100644 --- a/docs/.vuepress/sidebar.ts +++ b/docs/.vuepress/sidebar.ts @@ -23,6 +23,7 @@ export default sidebar({ "connection", "transaction", "cursor", + "listener", "results", "exceptions", ], diff --git a/docs/components/components_overview.md b/docs/components/components_overview.md index 30cf84e3..90b05b70 100644 --- a/docs/components/components_overview.md +++ b/docs/components/components_overview.md @@ -8,6 +8,7 @@ title: Components - `Connection`: represents single database connection, can be retrieved from `ConnectionPool`. - `Transaction`: represents database transaction, can be made from `Connection`. - `Cursor`: represents database cursor, can be made from `Transaction`. +- `Listener`: object to work with [LISTEN](https://www.postgresql.org/docs/current/sql-listen.html)/[NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) functionality, can be mode from `ConnectionPool`. - `QueryResult`: represents list of results from database. - `SingleQueryResult`: represents single result from the database. - `Exceptions`: we have some custom exceptions. diff --git a/docs/components/connection_pool.md b/docs/components/connection_pool.md index 15369612..bcee4893 100644 --- a/docs/components/connection_pool.md +++ b/docs/components/connection_pool.md @@ -254,6 +254,18 @@ This is the preferable way to work with the PostgreSQL. ::: +### Listener + +Create a new instance of a listener. + +```python +async def main() -> None: + ... + listener = db_pool.listener() +``` +``` + + ### Close To close the connection pool at the stop of your application. diff --git a/docs/components/exceptions.md b/docs/components/exceptions.md index aac3ecbd..fd2d86fb 100644 --- a/docs/components/exceptions.md +++ b/docs/components/exceptions.md @@ -15,6 +15,7 @@ stateDiagram-v2 RustPSQLDriverPyBaseError --> BaseConnectionError RustPSQLDriverPyBaseError --> BaseTransactionError RustPSQLDriverPyBaseError --> BaseCursorError + RustPSQLDriverPyBaseError --> BaseListenerError RustPSQLDriverPyBaseError --> RustException RustPSQLDriverPyBaseError --> RustToPyValueMappingError RustPSQLDriverPyBaseError --> PyToRustValueMappingError @@ -44,6 +45,11 @@ stateDiagram-v2 [*] --> CursorFetchError [*] --> CursorClosedError } + state BaseListenerError { + [*] --> ListenerStartError + [*] --> ListenerClosedError + [*] --> ListenerCallbackError + } state RustException { [*] --> DriverError [*] --> MacAddrParseError @@ -127,3 +133,15 @@ Error in cursor fetch (any fetch). #### CursorClosedError Error if underlying connection is closed. + +### BaseListenerError +Base error for all Listener errors. + +#### ListenerStartError +Error if listener start failed. + +#### ListenerClosedError +Error if listener manipulated but it's closed + +#### ListenerCallbackError +Error if callback passed to listener isn't a coroutine diff --git a/docs/components/listener.md b/docs/components/listener.md new file mode 100644 index 00000000..f2ecff30 --- /dev/null +++ b/docs/components/listener.md @@ -0,0 +1,204 @@ +--- +title: Listener +--- + +`Listener` object allows users to work with [LISTEN](https://www.postgresql.org/docs/current/sql-listen.html)/[NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) functionality. + +## Usage + +There are two ways of using `Listener` object: +- Async iterator +- Background task + +::: tabs + +@tab Background task +```python +from psqlpy import ConnectionPool, Connection, Listener + + +db_pool = ConnectionPool( + dsn="postgres://postgres:postgres@localhost:5432/postgres", +) + +async def test_channel_callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + # do some important staff + ... + +async def main() -> None: + # Create listener object + listener: Listener = db_pool.listener() + + # Add channel to listen and callback for it. + await listener.add_callback( + channel="test_channel", + callback=test_channel_callback, + ) + + # Startup the listener + await listener.startup() + + # Start listening. + # `listen` method isn't blocking, it returns None and starts background + # task in the Rust event loop. + listener.listen() + + # You can stop listening. + listener.abort_listen() +``` + +@tab Async Iterator +```python +from psqlpy import ( + ConnectionPool, + Connection, + Listener, + ListenerNotificationMsg, +) + + +db_pool = ConnectionPool( + dsn="postgres://postgres:postgres@localhost:5432/postgres", +) + +async def main() -> None: + # Create listener object + listener: Listener = db_pool.listener() + + # Startup the listener + await listener.startup() + + listener_msg: ListenerNotificationMsg + async for listener_msg in listener: + print(listener_msg) +``` + +::: + +## Listener attributes + +- `connection`: Instance of `Connection`. +If `startup` wasn't called, raises `ListenerStartError`. + +- `is_started`: Flag that shows whether the `Listener` is running or not. + +## Listener methods + +### Startup + +Startup `Listener` instance and can be called once or again only after `shutdown`. + +::: important +`Listener` must be started up. +::: + +```python +async def main() -> None: + listener: Listener = db_pool.listener() + + await listener.startup() +``` + +### Shutdown +Abort listen (if called) and release underlying connection. + +```python +async def main() -> None: + listener: Listener = db_pool.listener() + + await listener.startup() + await listener.shutdown() +``` + +### Add Callback + +#### Parameters: +- `channel`: name of the channel to listen. +- `callback`: coroutine callback. + +Add new callback to the channel, can be called more than 1 times. + +Callback signature is like this: +```python +from psqlpy import Connection + +async def callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + ... +``` + +Parameters for callback are based like `args`, so this signature is correct to: +```python +async def callback( + connection: Connection, + *args, +) -> None: + ... +``` + +**Example:** +```python +async def test_channel_callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + ... + +async def main() -> None: + listener = db_pool.listener() + + await listener.add_callback( + channel="test_channel", + callback=test_channel_callback, + ) +``` + +### Clear Channel Callbacks + +#### Parameters: +- `channel`: name of the channel + +Remove all callbacks for the channel + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.clear_channel_callbacks() +``` + +### Clear All Channels +Clear all channels and callbacks. + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.clear_all_channels() +``` + +### Listen +Start listening. + +It's a non-blocking operation. +In the background it creates task in Rust event loop. + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.startup() + await listener.listen() +``` + +### Abort Listen +Abort listen. +If `listen()` method was called, stop listening, else don't do anything. diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index 0fb00d10..6f899719 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -6,6 +6,8 @@ Cursor, IsolationLevel, KeepaliveConfig, + Listener, + ListenerNotificationMsg, LoadBalanceHosts, QueryResult, ReadVariant, @@ -25,6 +27,8 @@ "Cursor", "IsolationLevel", "KeepaliveConfig", + "Listener", + "ListenerNotificationMsg", "LoadBalanceHosts", "QueryResult", "ReadVariant", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index ff960651..42b836b2 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -2,7 +2,7 @@ import types from enum import Enum from io import BytesIO from ipaddress import IPv4Address, IPv6Address -from typing import Any, Callable, Sequence, TypeVar +from typing import Any, Awaitable, Callable, Sequence, TypeVar from typing_extensions import Buffer, Self @@ -1360,6 +1360,9 @@ class ConnectionPool: res = await connection.execute(...) ``` """ + def listener(self: Self) -> Listener: + """Create new listener.""" + def close(self: Self) -> None: """Close the connection pool.""" @@ -1748,3 +1751,136 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + +class Listener: + """Listener for LISTEN command. + + Can be used two ways: + 1) As a background task + 2) As an asynchronous iterator + + ## Examples + + ### Background task: + + ```python + async def callback( + channel: str, + payload: str, + process_id: int, + connection: Connection, + ) -> None: ... + async def main(): + pool = ConnectionPool() + + listener = pool.listener() + await listener.add_callback( + channel="test_channel", + callback=callback, + ) + await listener.startup() + + listener.listen() + ``` + + ### Async iterator + ```python + from psqlpy import + + async def msg_processor( + msg: ListenerNotificationMsg, + ) -> None: + ... + + + async def main(): + pool = ConnectionPool() + + listener = pool.listener() + await listener.add_callback( + channel="test_channel", + callback=callback, + ) + await listener.startup() + + for msg in listener: + await msg_processor(msg) + ``` + """ + + connection: Connection + is_started: bool + + def __aiter__(self: Self) -> Self: ... + async def __anext__(self: Self) -> ListenerNotificationMsg: ... + async def __aenter__(self: Self) -> Self: ... + async def __aexit__( + self: Self, + exception_type: type[BaseException] | None, + exception: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + async def startup(self: Self) -> None: + """Startup the listener. + + Each listener MUST be started up. + """ + async def shutdown(self: Self) -> None: + """Shutdown the listener. + + Abort listen and release underlying connection. + """ + async def add_callback( + self: Self, + channel: str, + callback: Callable[ + [Connection, str, str, int], + Awaitable[None], + ], + ) -> None: + """Add callback to the channel. + + Callback must be async function and have signature like this: + ```python + async def callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, + ) -> None: ... + ``` + + Callback parameters are passed as args on the Rust side. + """ + + async def clear_channel_callbacks(self, channel: str) -> None: + """Remove all callbacks for the channel. + + ### Parameters: + - `channel`: name of the channel. + """ + + async def clear_all_channels(self) -> None: + """Clear all channels callbacks.""" + + def listen(self: Self) -> None: + """Start listening. + + Start actual listening. + In the background it creates task in Rust event loop. + """ + + def abort_listen(self: Self) -> None: + """Abort listen. + + If `listen()` method was called, stop listening, + else don't do anything. + """ + +class ListenerNotificationMsg: + """Listener message in async iterator.""" + + process_id: int + channel: str + payload: str + connection: Connection diff --git a/python/psqlpy/_internal/exceptions.pyi b/python/psqlpy/_internal/exceptions.pyi index bad0fc8c..a0588e9f 100644 --- a/python/psqlpy/_internal/exceptions.pyi +++ b/python/psqlpy/_internal/exceptions.pyi @@ -79,3 +79,15 @@ class PyToRustValueMappingError(RustPSQLDriverPyBaseError): You can get this exception when executing queries with parameters. So, if there are no parameters for the query, don't handle this error. """ + +class BaseListenerError(RustPSQLDriverPyBaseError): + """Base error for all Listener errors.""" + +class ListenerStartError(BaseListenerError): + """Error if listener start failed.""" + +class ListenerClosedError(BaseListenerError): + """Error if listener manipulated but it's closed.""" + +class ListenerCallbackError(BaseListenerError): + """Error if callback passed to listener isn't a coroutine.""" diff --git a/python/psqlpy/exceptions.py b/python/psqlpy/exceptions.py index c8240b9e..2d981ef3 100644 --- a/python/psqlpy/exceptions.py +++ b/python/psqlpy/exceptions.py @@ -2,6 +2,7 @@ BaseConnectionError, BaseConnectionPoolError, BaseCursorError, + BaseListenerError, BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, @@ -12,6 +13,9 @@ CursorCloseError, CursorFetchError, CursorStartError, + ListenerCallbackError, + ListenerClosedError, + ListenerStartError, MacAddrConversionError, PyToRustValueMappingError, RustPSQLDriverPyBaseError, @@ -29,6 +33,7 @@ "BaseConnectionError", "BaseConnectionPoolError", "BaseCursorError", + "BaseListenerError", "BaseTransactionError", "ConnectionClosedError", "ConnectionExecuteError", @@ -39,6 +44,9 @@ "CursorClosedError", "CursorFetchError", "CursorStartError", + "ListenerCallbackError", + "ListenerClosedError", + "ListenerStartError", "MacAddrConversionError", "PyToRustValueMappingError", "RustPSQLDriverPyBaseError", diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 515ff87d..bfa3f650 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -68,6 +68,11 @@ def table_name() -> str: return random_string() +@pytest.fixture +def listener_table_name() -> str: + return random_string() + + @pytest.fixture def number_database_records() -> int: return random.randint(10, 35) @@ -137,6 +142,23 @@ async def create_default_data_for_tests( ) +@pytest.fixture +async def create_table_for_listener_tests( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> AsyncGenerator[None, None]: + await psql_pool.execute( + f"CREATE TABLE {listener_table_name}" + f"(id SERIAL, payload VARCHAR(255)," + f"channel VARCHAR(255), process_id INT)", + ) + + yield + await psql_pool.execute( + f"DROP TABLE {listener_table_name}", + ) + + @pytest.fixture async def test_cursor( psql_pool: ConnectionPool, diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py new file mode 100644 index 00000000..46722ca1 --- /dev/null +++ b/python/tests/test_listener.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import asyncio +import typing + +import pytest +from psqlpy.exceptions import ListenerStartError + +if typing.TYPE_CHECKING: + from psqlpy import Connection, ConnectionPool, Listener + +pytestmark = pytest.mark.anyio + + +TEST_CHANNEL = "test_channel" +TEST_PAYLOAD = "test_payload" + + +async def construct_listener( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> Listener: + listener = psql_pool.listener() + await listener.add_callback( + channel=TEST_CHANNEL, + callback=construct_insert_callback( + listener_table_name=listener_table_name, + ), + ) + return listener + + +def construct_insert_callback( + listener_table_name: str, +) -> typing.Callable[ + [Connection, str, str, int], + typing.Awaitable[None], +]: + async def callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, + ) -> None: + await connection.execute( + querystring=f"INSERT INTO {listener_table_name} VALUES (1, $1, $2, $3)", + parameters=( + payload, + channel, + process_id, + ), + ) + + return callback + + +async def notify( + psql_pool: ConnectionPool, + channel: str = TEST_CHANNEL, + with_delay: bool = False, +) -> None: + if with_delay: + await asyncio.sleep(0.5) + + await psql_pool.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") + + +async def check_insert_callback( + psql_pool: ConnectionPool, + listener_table_name: str, + is_insert_exist: bool = True, + number_of_data: int = 1, +) -> None: + test_data_seq = ( + await psql_pool.execute( + f"SELECT * FROM {listener_table_name}", + ) + ).result() + + if is_insert_exist: + assert len(test_data_seq) == number_of_data + else: + assert not len(test_data_seq) + return + + data_record = test_data_seq[0] + + assert data_record["payload"] == TEST_PAYLOAD + assert data_record["channel"] == TEST_CHANNEL + + +async def clear_test_table( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + await psql_pool.execute( + f"DELETE FROM {listener_table_name}", + ) + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_listen( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + """Test that single connection can execute queries.""" + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + listener.listen() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_asynciterator( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + + asyncio.create_task( # noqa: RUF006 + notify( + psql_pool=psql_pool, + with_delay=True, + ), + ) + + async for listener_msg in listener: + assert listener_msg.channel == TEST_CHANNEL + assert listener_msg.payload == TEST_PAYLOAD + break + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_abort( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + listener.listen() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + listener.abort_listen() + + await clear_test_table( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + +async def test_listener_start_exc( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + with pytest.raises(expected_exception=ListenerStartError): + listener.listen() + + +async def test_listener_double_start_exc( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + + with pytest.raises(expected_exception=ListenerStartError): + await listener.startup() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_more_than_one_callback( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + additional_channel = "test_channel_2" + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.add_callback( + channel=additional_channel, + callback=construct_insert_callback( + listener_table_name=listener_table_name, + ), + ) + await listener.startup() + listener.listen() + + for channel in [TEST_CHANNEL, additional_channel]: + await notify( + psql_pool=psql_pool, + channel=channel, + ) + + await asyncio.sleep(0.5) + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + number_of_data=2, + ) + + query_result = await psql_pool.execute( + querystring=(f"SELECT * FROM {listener_table_name} WHERE channel = $1"), + parameters=(additional_channel,), + ) + + data_result = query_result.result()[0] + + assert data_result["channel"] == additional_channel + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_clear_callbacks( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await listener.startup() + listener.listen() + + await listener.clear_channel_callbacks( + channel=TEST_CHANNEL, + ) + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_clear_all_callbacks( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await listener.startup() + listener.listen() + + await listener.clear_all_channels() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + listener.abort_listen() diff --git a/src/common.rs b/src/common.rs index 0a4382ec..8dc70fc3 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,10 +1,10 @@ -use deadpool_postgres::Object; use pyo3::{ types::{PyAnyMethods, PyModule, PyModuleMethods}, Bound, PyAny, PyResult, Python, }; use crate::{ + driver::connection::PsqlpyConnection, exceptions::rust_errors::RustPSQLDriverPyResult, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, value_converter::{convert_parameters, PythonDTO, QueryParameter}, @@ -24,10 +24,10 @@ pub fn add_module( child_mod_name: &'static str, child_mod_builder: impl FnOnce(Python<'_>, &Bound<'_, PyModule>) -> PyResult<()>, ) -> PyResult<()> { - let sub_module = PyModule::new_bound(py, child_mod_name)?; + let sub_module = PyModule::new(py, child_mod_name)?; child_mod_builder(py, &sub_module)?; parent_mod.add_submodule(&sub_module)?; - py.import_bound("sys")?.getattr("modules")?.set_item( + py.import("sys")?.getattr("modules")?.set_item( format!("{}.{}", parent_mod.name()?, child_mod_name), sub_module, )?; @@ -55,7 +55,7 @@ pub trait ObjectQueryTrait { ) -> impl std::future::Future> + Send; } -impl ObjectQueryTrait for Object { +impl ObjectQueryTrait for PsqlpyConnection { async fn psqlpy_query_one( &self, querystring: String, @@ -131,6 +131,6 @@ impl ObjectQueryTrait for Object { } async fn psqlpy_query_simple(&self, querystring: String) -> RustPSQLDriverPyResult<()> { - Ok(self.batch_execute(querystring.as_str()).await?) + self.batch_execute(querystring.as_str()).await } } diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 97dc66a1..f10328c2 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,9 +1,12 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use deadpool_postgres::{Object, Pool}; use futures_util::pin_mut; +use postgres_types::ToSql; use pyo3::{buffer::PyBuffer, pyclass, pymethods, Py, PyAny, PyErr, Python}; use std::{collections::HashSet, sync::Arc, vec}; -use tokio_postgres::binary_copy::BinaryCopyInWriter; +use tokio_postgres::{ + binary_copy::BinaryCopyInWriter, Client, CopyInSink, Row, Statement, ToStatement, +}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -19,110 +22,121 @@ use super::{ transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; -/// Format OPTS parameter for Postgres COPY command. -/// -/// # Errors -/// May return Err Result if cannot format parameter. -#[allow(clippy::too_many_arguments)] -pub fn _format_copy_opts( - format: Option, - freeze: Option, - delimiter: Option, - null: Option, - header: Option, - quote: Option, - escape: Option, - force_quote: Option>, - force_not_null: Option>, - force_null: Option>, - encoding: Option, -) -> RustPSQLDriverPyResult { - let mut opts: Vec = vec![]; - - if let Some(format) = format { - opts.push(format!("FORMAT {format}")); - } +#[allow(clippy::module_name_repetitions)] +pub enum PsqlpyConnection { + PoolConn(Object), + SingleConn(Client), +} - if let Some(freeze) = freeze { - if freeze { - opts.push("FREEZE TRUE".into()); - } else { - opts.push("FREEZE FALSE".into()); +impl PsqlpyConnection { + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot prepare statement. + pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), } } - if let Some(delimiter) = delimiter { - opts.push(format!("DELIMITER {delimiter}")); - } - - if let Some(null) = null { - opts.push(format!("NULL {}", quote_ident(&null))); - } - - if let Some(header) = header { - opts.push(format!("HEADER {header}")); - } - - if let Some(quote) = quote { - opts.push(format!("QUOTE {quote}")); + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult> + where + T: ?Sized + ToStatement, + { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query(statement, params).await?) + } + } } - if let Some(escape) = escape { - opts.push(format!("ESCAPE {escape}")); + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), + } } - if let Some(force_quote) = force_quote { - let boolean_force_quote: Result = - Python::with_gil(|gil| force_quote.extract::(gil)); - - if let Ok(force_quote) = boolean_force_quote { - if force_quote { - opts.push("FORCE_QUOTE *".into()); + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult + where + T: ?Sized + ToStatement, + { + match self { + PsqlpyConnection::PoolConn(pconn) => { + return Ok(pconn.query_one(statement, params).await?) } - } else { - let sequence_force_quote: Result, PyErr> = - Python::with_gil(|gil| force_quote.extract::>(gil)); - - if let Ok(force_quote) = sequence_force_quote { - opts.push(format!("FORCE_QUOTE ({})", force_quote.join(", "))); + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query_one(statement, params).await?) } - - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "force_quote parameter must be boolean or sequence of str's.".into(), - )); } } - if let Some(force_not_null) = force_not_null { - opts.push(format!("FORCE_NOT_NULL ({})", force_not_null.join(", "))); - } - - if let Some(force_null) = force_null { - opts.push(format!("FORCE_NULL ({})", force_null.join(", "))); - } - - if let Some(encoding) = encoding { - opts.push(format!("ENCODING {}", quote_ident(&encoding))); - } - - if opts.is_empty() { - Ok(String::new()) - } else { - Ok(format!("({})", opts.join(", "))) + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute copy data. + pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), + } } } #[pyclass(subclass)] +#[derive(Clone)] pub struct Connection { - db_client: Option>, + db_client: Option>, db_pool: Option, } impl Connection { #[must_use] - pub fn new(db_client: Option>, db_pool: Option) -> Self { + pub fn new(db_client: Option>, db_pool: Option) -> Self { Connection { db_client, db_pool } } + + #[must_use] + pub fn db_client(&self) -> Option> { + self.db_client.clone() + } + + #[must_use] + pub fn db_pool(&self) -> Option { + self.db_pool.clone() + } +} + +impl Default for Connection { + fn default() -> Self { + Connection::new(None, None) + } } #[pymethods] @@ -145,7 +159,7 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(db_connection)); + self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))); }); return Ok(self_); } @@ -163,7 +177,7 @@ impl Connection { let (is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { ( exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), ) }); @@ -271,7 +285,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return Ok(db_client.batch_execute(&querystring).await?); + return db_client.batch_execute(&querystring).await; } Err(RustPSQLDriverError::ConnectionClosedError) diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 8f5ea984..4f38407d 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,10 +1,8 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; -use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::{sync::Arc, vec}; -use tokio_postgres::NoTls; +use tokio_postgres::Config; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -13,9 +11,10 @@ use crate::{ }; use super::{ - common_options::{self, ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, - connection::Connection, - utils::build_connection_config, + common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, + connection::{Connection, PsqlpyConnection}, + listener::core::Listener, + utils::{build_connection_config, build_manager, build_tls}, }; /// Make new connection pool. @@ -77,7 +76,6 @@ pub fn connect( load_balance_hosts: Option, ssl_mode: Option, ca_file: Option, - max_db_pool_size: Option, conn_recycling_method: Option, ) -> RustPSQLDriverPyResult { @@ -126,33 +124,25 @@ pub fn connect( }; } - let mgr: Manager; - if let Some(ca_file) = ca_file { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_ca_file(ca_file)?; - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(pg_config, tls_connector, mgr_config); - } else if let Some(ssl_mode) = ssl_mode { - if ssl_mode == common_options::SslMode::Require { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_verify(SslVerifyMode::NONE); - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(pg_config, tls_connector, mgr_config); - } else { - mgr = Manager::from_config(pg_config, NoTls, mgr_config); - } - } else { - mgr = Manager::from_config(pg_config, NoTls, mgr_config); - } + let mgr: Manager = build_manager( + mgr_config, + pg_config.clone(), + build_tls(&ca_file, &ssl_mode)?, + ); let mut db_pool_builder = Pool::builder(mgr); if let Some(max_db_pool_size) = max_db_pool_size { db_pool_builder = db_pool_builder.max_size(max_db_pool_size); } - let db_pool = db_pool_builder.build()?; + let pool = db_pool_builder.build()?; - Ok(ConnectionPool(db_pool)) + Ok(ConnectionPool { + pool, + pg_config, + ca_file, + ssl_mode, + }) } #[pyclass] @@ -212,8 +202,32 @@ impl ConnectionPoolStatus { } } +// #[pyclass(subclass)] +// pub struct ConnectionPool(pub Pool); #[pyclass(subclass)] -pub struct ConnectionPool(pub Pool); +pub struct ConnectionPool { + pool: Pool, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, +} + +impl ConnectionPool { + #[must_use] + pub fn build( + pool: Pool, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, + ) -> Self { + ConnectionPool { + pool, + pg_config, + ca_file, + ssl_mode, + } + } +} #[pymethods] impl ConnectionPool { @@ -333,7 +347,7 @@ impl ConnectionPool { #[must_use] pub fn status(&self) -> ConnectionPoolStatus { - let inner_status = self.0.status(); + let inner_status = self.pool.status(); ConnectionPoolStatus::new( inner_status.max_size, @@ -344,7 +358,7 @@ impl ConnectionPool { } pub fn resize(&self, new_max_size: usize) { - self.0.resize(new_max_size); + self.pool.resize(new_max_size); } /// Execute querystring with parameters. @@ -361,7 +375,7 @@ impl ConnectionPool { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_pool_manager = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) @@ -430,7 +444,7 @@ impl ConnectionPool { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_pool_manager = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) @@ -484,7 +498,22 @@ impl ConnectionPool { #[must_use] pub fn acquire(&self) -> Connection { - Connection::new(None, Some(self.0.clone())) + Connection::new(None, Some(self.pool.clone())) + } + + #[must_use] + #[allow(clippy::needless_pass_by_value)] + pub fn listener(self_: pyo3::Py) -> Listener { + let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { + let b_gil = self_.borrow(gil); + ( + b_gil.pg_config.clone(), + b_gil.ca_file.clone(), + b_gil.ssl_mode, + ) + }); + + Listener::new(pg_config, ca_file, ssl_mode) } /// Return new single connection. @@ -492,14 +521,17 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub async fn connection(self_: pyo3::Py) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) .await??; - Ok(Connection::new(Some(Arc::new(db_connection)), None)) + Ok(Connection::new( + Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))), + None, + )) } /// Close connection pool. @@ -507,7 +539,7 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub fn close(&self) { - let db_pool = self.0.clone(); + let db_pool = self.pool.clone(); db_pool.close(); } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index bb2047b4..e0610942 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -1,14 +1,15 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; -use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pymethods, Py, Python}; -use tokio_postgres::NoTls; use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; -use super::{common_options, connection_pool::ConnectionPool}; +use super::{ + common_options, + connection_pool::ConnectionPool, + utils::{build_manager, build_tls}, +}; #[pyclass] pub struct ConnectionPoolBuilder { @@ -49,24 +50,11 @@ impl ConnectionPoolBuilder { }; }; - let mgr: Manager; - if let Some(ca_file) = &self.ca_file { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_ca_file(ca_file)?; - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); - } else if let Some(ssl_mode) = self.ssl_mode { - if ssl_mode == common_options::SslMode::Require { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_verify(SslVerifyMode::NONE); - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); - } else { - mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); - } - } else { - mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); - } + let mgr: Manager = build_manager( + mgr_config, + self.config.clone(), + build_tls(&self.ca_file, &self.ssl_mode)?, + ); let mut db_pool_builder = Pool::builder(mgr); if let Some(max_db_pool_size) = self.max_db_pool_size { @@ -75,7 +63,12 @@ impl ConnectionPoolBuilder { let db_pool = db_pool_builder.build()?; - Ok(ConnectionPool(db_pool)) + Ok(ConnectionPool::build( + db_pool, + self.config.clone(), + self.ca_file.clone(), + self.ssl_mode, + )) } /// Set ca_file for ssl_mode in PostgreSQL. diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 3f8008be..7368d29a 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use deadpool_postgres::Object; use pyo3::{ exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, }; @@ -12,6 +11,8 @@ use crate::{ runtime::rustdriver_future, }; +use super::connection::PsqlpyConnection; + /// Additional implementation for the `Object` type. #[allow(clippy::ref_option)] trait CursorObjectTrait { @@ -27,7 +28,7 @@ trait CursorObjectTrait { async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; } -impl CursorObjectTrait for Object { +impl CursorObjectTrait for PsqlpyConnection { /// Start the cursor. /// /// Execute `DECLARE` command with parameters. @@ -89,7 +90,7 @@ impl CursorObjectTrait for Object { #[pyclass(subclass)] pub struct Cursor { - db_transaction: Option>, + db_transaction: Option>, querystring: String, parameters: Option>, cursor_name: String, @@ -103,7 +104,7 @@ pub struct Cursor { impl Cursor { #[must_use] pub fn new( - db_transaction: Arc, + db_transaction: Arc, querystring: String, parameters: Option>, cursor_name: String, @@ -178,7 +179,7 @@ impl Cursor { self_.closed, self_.cursor_name.clone(), exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), ) }); diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs new file mode 100644 index 00000000..c8fd271c --- /dev/null +++ b/src/driver/listener/core.rs @@ -0,0 +1,379 @@ +use std::sync::Arc; + +use futures::{stream, FutureExt, StreamExt, TryStreamExt}; +use futures_channel::mpsc::UnboundedReceiver; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; +use postgres_openssl::MakeTlsConnector; +use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, Python}; +use tokio::{ + sync::RwLock, + task::{AbortHandle, JoinHandle}, +}; +use tokio_postgres::{AsyncMessage, Config}; + +use crate::{ + driver::{ + common_options::SslMode, + connection::{Connection, PsqlpyConnection}, + utils::{build_tls, is_coroutine_function, ConfiguredTLS}, + }, + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + runtime::{rustdriver_future, tokio_runtime}, +}; + +use super::structs::{ + ChannelCallbacks, ListenerCallback, ListenerNotification, ListenerNotificationMsg, +}; + +#[pyclass] +pub struct Listener { + pg_config: Config, + ca_file: Option, + ssl_mode: Option, + channel_callbacks: Arc>, + listen_abort_handler: Option, + connection: Connection, + receiver: Option>>>, + listen_query: Arc>, + is_listened: Arc>, + is_started: bool, +} + +impl Listener { + #[must_use] + pub fn new(pg_config: Config, ca_file: Option, ssl_mode: Option) -> Self { + Listener { + pg_config, + ca_file, + ssl_mode, + channel_callbacks: Arc::default(), + listen_abort_handler: Option::default(), + connection: Connection::new(None, None), + receiver: Option::default(), + listen_query: Arc::default(), + is_listened: Arc::new(RwLock::new(false)), + is_started: false, + } + } + + async fn update_listen_query(&self) { + let read_channel_callbacks = self.channel_callbacks.read().await; + + let channels = read_channel_callbacks.retrieve_all_channels(); + + let mut final_query: String = String::default(); + + for channel_name in channels { + final_query.push_str(format!("LISTEN {channel_name};").as_str()); + } + + let mut write_listen_query = self.listen_query.write().await; + let mut write_is_listened = self.is_listened.write().await; + + write_listen_query.clear(); + write_listen_query.push_str(&final_query); + *write_is_listened = false; + } +} + +#[pymethods] +impl Listener { + #[must_use] + fn __aiter__(slf: Py) -> Py { + slf + } + + fn __await__(slf: Py) -> Py { + slf + } + + #[allow(clippy::unused_async)] + async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + Ok(slf) + } + + #[allow(clippy::unused_async)] + async fn __aexit__<'a>( + slf: Py, + _exception_type: Py, + exception: Py, + _traceback: Py, + ) -> RustPSQLDriverPyResult<()> { + let (client, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { + let self_ = slf.borrow(gil); + ( + self_.connection.db_client(), + exception.is_none(gil), + PyErr::from_value(exception.into_bound(gil)), + ) + }); + + if client.is_some() { + pyo3::Python::with_gil(|gil| { + let mut self_ = slf.borrow_mut(gil); + std::mem::take(&mut self_.connection); + std::mem::take(&mut self_.receiver); + }); + + if !is_exception_none { + return Err(RustPSQLDriverError::RustPyError(py_err)); + } + + return Ok(()); + } + + Err(RustPSQLDriverError::ListenerClosedError) + } + + fn __anext__(&self) -> RustPSQLDriverPyResult>> { + let Some(client) = self.connection.db_client() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying client, please call startup".into(), + )); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying receiver, please call startup".into(), + )); + }; + + let is_listened_clone = self.is_listened.clone(); + let listen_query_clone = self.listen_query.clone(); + let connection = self.connection.clone(); + + let py_future = Python::with_gil(move |gil| { + rustdriver_future(gil, async move { + { + execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; + }; + let next_element = { + let mut write_receiver = receiver.write().await; + write_receiver.next().await + }; + + let inner_notification = process_message(next_element)?; + + Ok(ListenerNotificationMsg::new(inner_notification, connection)) + }) + }); + + Ok(Some(py_future?)) + } + + #[getter] + fn is_started(&self) -> bool { + self.is_started + } + + #[getter] + fn connection(&self) -> RustPSQLDriverPyResult { + if !self.is_started { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener isn't started up".into(), + )); + } + + Ok(self.connection.clone()) + } + + async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { + if self.is_started { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener is already started".into(), + )); + } + + let tls_ = build_tls(&self.ca_file, &self.ssl_mode)?; + + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + + let pg_config = self.pg_config.clone(); + let connect_future = async move { + match tls_ { + ConfiguredTLS::NoTls => { + return pg_config + .connect(MakeTlsConnector::new(builder.build())) + .await; + } + ConfiguredTLS::TlsConnector(connector) => { + return pg_config.connect(connector).await; + } + } + }; + + let (client, mut connection) = tokio_runtime().spawn(connect_future).await??; + + let (transmitter, receiver) = futures_channel::mpsc::unbounded::(); + + let stream = + stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + + let connection = stream.forward(transmitter).map(|r| { + r.map_err(|_| { + RustPSQLDriverError::ListenerStartError("Cannot startup the listener".into()) + }) + }); + tokio_runtime().spawn(connection); + + self.receiver = Some(Arc::new(RwLock::new(receiver))); + self.connection = + Connection::new(Some(Arc::new(PsqlpyConnection::SingleConn(client))), None); + + self.is_started = true; + + Ok(()) + } + + async fn shutdown(&mut self) { + self.abort_listen(); + std::mem::take(&mut self.connection); + std::mem::take(&mut self.receiver); + + self.is_started = false; + } + + #[pyo3(signature = (channel, callback))] + async fn add_callback( + &mut self, + channel: String, + callback: Py, + ) -> RustPSQLDriverPyResult<()> { + if !is_coroutine_function(callback.clone())? { + return Err(RustPSQLDriverError::ListenerCallbackError); + } + + let task_locals = Python::with_gil(pyo3_async_runtimes::tokio::get_current_locals)?; + + let listener_callback = ListenerCallback::new(task_locals, callback); + + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.add_callback(channel, listener_callback); + } + + self.update_listen_query().await; + + Ok(()) + } + + async fn clear_channel_callbacks(&mut self, channel: String) { + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.clear_channel_callbacks(&channel); + } + + self.update_listen_query().await; + } + + async fn clear_all_channels(&mut self) { + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.clear_all(); + } + + self.update_listen_query().await; + } + + fn listen(&mut self) -> RustPSQLDriverPyResult<()> { + let Some(client) = self.connection.db_client() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Cannot start listening, underlying connection doesn't exist".into(), + )); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Cannot start listening, underlying connection doesn't exist".into(), + )); + }; + + let connection = self.connection.clone(); + let listen_query_clone = self.listen_query.clone(); + let is_listened_clone = self.is_listened.clone(); + + let channel_callbacks = self.channel_callbacks.clone(); + + let jh: JoinHandle> = tokio_runtime().spawn(async move { + loop { + { + execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; + }; + + let next_element = { + let mut write_receiver = receiver.write().await; + write_receiver.next().await + }; + + let inner_notification = process_message(next_element)?; + + let read_channel_callbacks = channel_callbacks.read().await; + let channel = inner_notification.channel.clone(); + let callbacks = read_channel_callbacks.retrieve_channel_callbacks(&channel); + + if let Some(callbacks) = callbacks { + for callback in callbacks { + dispatch_callback(callback, inner_notification.clone(), connection.clone()) + .await?; + } + } + } + }); + + let abj = jh.abort_handle(); + + self.listen_abort_handler = Some(abj); + + Ok(()) + } + + fn abort_listen(&mut self) { + if let Some(listen_abort_handler) = &self.listen_abort_handler { + listen_abort_handler.abort(); + } + + self.listen_abort_handler = None; + } +} + +async fn dispatch_callback( + listener_callback: &ListenerCallback, + listener_notification: ListenerNotification, + connection: Connection, +) -> RustPSQLDriverPyResult<()> { + listener_callback + .call(listener_notification.clone(), connection) + .await?; + + Ok(()) +} + +async fn execute_listen( + is_listened: &Arc>, + listen_query: &Arc>, + client: &Arc, +) -> RustPSQLDriverPyResult<()> { + let mut write_is_listened = is_listened.write().await; + + if !write_is_listened.eq(&true) { + let listen_q = { + let read_listen_query = listen_query.read().await; + String::from(read_listen_query.as_str()) + }; + + client.batch_execute(listen_q.as_str()).await?; + } + + *write_is_listened = true; + Ok(()) +} + +fn process_message(message: Option) -> RustPSQLDriverPyResult { + let Some(async_message) = message else { + return Err(RustPSQLDriverError::ListenerError("Wow".into())); + }; + let AsyncMessage::Notification(notification) = async_message else { + return Err(RustPSQLDriverError::ListenerError("Wow".into())); + }; + + Ok(ListenerNotification::from(notification)) +} diff --git a/src/driver/listener/mod.rs b/src/driver/listener/mod.rs new file mode 100644 index 00000000..b67b7b0b --- /dev/null +++ b/src/driver/listener/mod.rs @@ -0,0 +1,2 @@ +pub mod core; +pub mod structs; diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs new file mode 100644 index 00000000..4d53a408 --- /dev/null +++ b/src/driver/listener/structs.rs @@ -0,0 +1,162 @@ +use std::collections::{hash_map::Entry, HashMap}; + +use pyo3::{pyclass, pymethods, Py, PyAny, Python}; +use pyo3_async_runtimes::TaskLocals; +use tokio_postgres::Notification; + +use crate::{ + driver::connection::Connection, + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + runtime::tokio_runtime, +}; + +#[derive(Default)] +pub struct ChannelCallbacks(HashMap>); + +impl ChannelCallbacks { + pub fn add_callback(&mut self, channel: String, callback: ListenerCallback) { + match self.0.entry(channel) { + Entry::Vacant(e) => { + e.insert(vec![callback]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(callback); + } + }; + } + + #[must_use] + pub fn retrieve_channel_callbacks(&self, channel: &str) -> Option<&Vec> { + self.0.get(channel) + } + + pub fn clear_channel_callbacks(&mut self, channel: &str) { + self.0.remove(channel); + } + + pub fn clear_all(&mut self) { + self.0.clear(); + } + + #[must_use] + pub fn retrieve_all_channels(&self) -> Vec<&String> { + self.0.keys().collect::>() + } +} + +#[derive(Clone, Debug)] +pub struct ListenerNotification { + pub process_id: i32, + pub channel: String, + pub payload: String, +} + +impl From for ListenerNotification { + fn from(value: Notification) -> Self { + ListenerNotification { + process_id: value.process_id(), + channel: String::from(value.channel()), + payload: String::from(value.payload()), + } + } +} + +#[pyclass] +pub struct ListenerNotificationMsg { + process_id: i32, + channel: String, + payload: String, + connection: Connection, +} + +#[pymethods] +impl ListenerNotificationMsg { + #[getter] + fn process_id(&self) -> i32 { + self.process_id + } + + #[getter] + fn channel(&self) -> String { + self.channel.clone() + } + + #[getter] + fn payload(&self) -> String { + self.payload.clone() + } + + #[getter] + fn connection(&self) -> Connection { + self.connection.clone() + } +} + +impl ListenerNotificationMsg { + #[must_use] + pub fn new(value: ListenerNotification, conn: Connection) -> Self { + ListenerNotificationMsg { + process_id: value.process_id, + channel: value.channel, + payload: value.payload, + connection: conn, + } + } +} + +pub struct ListenerCallback { + task_locals: TaskLocals, + callback: Py, +} + +impl ListenerCallback { + #[must_use] + pub fn new(task_locals: TaskLocals, callback: Py) -> Self { + ListenerCallback { + task_locals, + callback, + } + } + + /// Dispatch the callback. + /// + /// # Errors + /// May return Err Result if cannot call python future. + pub async fn call( + &self, + lister_notification: ListenerNotification, + connection: Connection, + ) -> RustPSQLDriverPyResult<()> { + let (callback, task_locals) = + Python::with_gil(|py| (self.callback.clone(), self.task_locals.clone_ref(py))); + + tokio_runtime() + .spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { + let future = Python::with_gil(|py| { + let awaitable = callback + .call1( + py, + ( + connection, + lister_notification.payload, + lister_notification.channel, + lister_notification.process_id, + ), + ) + .map_err(|_| RustPSQLDriverError::ListenerCallbackError)?; + let aba = pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py))?; + Ok(aba) + }); + Ok::, RustPSQLDriverError>( + future + .map_err(|_: RustPSQLDriverError| { + RustPSQLDriverError::ListenerCallbackError + })? + .await?, + ) + })) + .await??; + + Ok(()) + } +} diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 183a0343..578bf2cd 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,6 +3,7 @@ pub mod connection; pub mod connection_pool; pub mod connection_pool_builder; pub mod cursor; +pub mod listener; pub mod transaction; pub mod transaction_options; pub mod utils; diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index bc992c68..3fa59e4d 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -1,5 +1,4 @@ use bytes::BytesMut; -use deadpool_postgres::Object; use futures_util::{future, pin_mut}; use pyo3::{ buffer::PyBuffer, @@ -17,6 +16,7 @@ use crate::{ }; use super::{ + connection::PsqlpyConnection, cursor::Cursor, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; @@ -36,7 +36,7 @@ pub trait TransactionObjectTrait { fn rollback(&self) -> impl std::future::Future> + Send; } -impl TransactionObjectTrait for Object { +impl TransactionObjectTrait for PsqlpyConnection { async fn start_transaction( &self, isolation_level: Option, @@ -106,7 +106,7 @@ impl TransactionObjectTrait for Object { #[pyclass(subclass)] pub struct Transaction { - pub db_client: Option>, + pub db_client: Option>, is_started: bool, is_done: bool, @@ -122,7 +122,7 @@ impl Transaction { #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( - db_client: Arc, + db_client: Arc, is_started: bool, is_done: bool, isolation_level: Option, @@ -236,7 +236,7 @@ impl Transaction { ( self_.check_is_transaction_ready(), exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), self_.db_client.clone(), ) }); @@ -355,7 +355,7 @@ impl Transaction { }); is_transaction_ready?; if let Some(db_client) = db_client { - return Ok(db_client.batch_execute(&querystring).await?); + return db_client.batch_execute(&querystring).await; } Err(RustPSQLDriverError::TransactionClosedError) diff --git a/src/driver/transaction_options.rs b/src/driver/transaction_options.rs index 51467761..281b9a71 100644 --- a/src/driver/transaction_options.rs +++ b/src/driver/transaction_options.rs @@ -71,3 +71,11 @@ impl SynchronousCommit { } } } + +#[derive(Clone, Copy, PartialEq)] +pub struct ListenerTransactionConfig { + isolation_level: Option, + read_variant: Option, + deferrable: Option, + synchronous_commit: Option, +} diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 33bf0aea..3d0d59e3 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -1,8 +1,14 @@ use std::{str::FromStr, time::Duration}; +use deadpool_postgres::{Manager, ManagerConfig}; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; +use postgres_openssl::MakeTlsConnector; +use pyo3::{types::PyAnyMethods, Py, PyAny, Python}; +use tokio_postgres::{Config, NoTls}; + use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; -use super::common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs}; +use super::common_options::{self, LoadBalanceHosts, SslMode, TargetSessionAttrs}; /// Create new config. /// @@ -163,3 +169,71 @@ pub fn build_connection_config( Ok(pg_config) } + +pub enum ConfiguredTLS { + NoTls, + TlsConnector(MakeTlsConnector), +} + +/// Create TLS. +/// +/// # Errors +/// May return Err Result if cannot create builder. +pub fn build_tls( + ca_file: &Option, + ssl_mode: &Option, +) -> RustPSQLDriverPyResult { + if let Some(ca_file) = ca_file { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_ca_file(ca_file)?; + return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( + builder.build(), + ))); + } else if let Some(ssl_mode) = ssl_mode { + if *ssl_mode == common_options::SslMode::Require { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( + builder.build(), + ))); + } + } + + Ok(ConfiguredTLS::NoTls) +} + +#[must_use] +pub fn build_manager( + mgr_config: ManagerConfig, + pg_config: Config, + configured_tls: ConfiguredTLS, +) -> Manager { + let mgr: Manager = match configured_tls { + ConfiguredTLS::NoTls => Manager::from_config(pg_config, NoTls, mgr_config), + ConfiguredTLS::TlsConnector(connector) => { + Manager::from_config(pg_config, connector, mgr_config) + } + }; + + mgr +} + +/// Check is python object async or not. +/// +/// # Errors +/// May return Err Result if cannot +/// 1) import inspect +/// 2) extract boolean +pub fn is_coroutine_function(function: Py) -> RustPSQLDriverPyResult { + let is_coroutine_function: bool = Python::with_gil(|py| { + let inspect = py.import("inspect")?; + + let is_cor = inspect + .call_method1("iscoroutinefunction", (function,)) + .map_err(|_| RustPSQLDriverError::ListenerClosedError)? + .extract::()?; + Ok::(is_cor) + })?; + + Ok(is_coroutine_function) +} diff --git a/src/exceptions/python_errors.rs b/src/exceptions/python_errors.rs index bd5fc641..4d3798cb 100644 --- a/src/exceptions/python_errors.rs +++ b/src/exceptions/python_errors.rs @@ -3,6 +3,7 @@ use pyo3::{ types::{PyModule, PyModuleMethods}, Bound, PyResult, Python, }; + // Main exception. create_exception!( psqlpy.exceptions, @@ -105,6 +106,16 @@ create_exception!(psqlpy.exceptions, CursorCloseError, BaseCursorError); create_exception!(psqlpy.exceptions, CursorFetchError, BaseCursorError); create_exception!(psqlpy.exceptions, CursorClosedError, BaseCursorError); +// Listener Error +create_exception!( + psqlpy.exceptions, + BaseListenerError, + RustPSQLDriverPyBaseError +); +create_exception!(psqlpy.exceptions, ListenerStartError, BaseListenerError); +create_exception!(psqlpy.exceptions, ListenerClosedError, BaseListenerError); +create_exception!(psqlpy.exceptions, ListenerCallbackError, BaseListenerError); + // Inner exceptions create_exception!( psqlpy.exceptions, @@ -132,95 +143,97 @@ create_exception!( create_exception!(psqlpy.exceptions, SSLError, RustPSQLDriverPyBaseError); #[allow(clippy::missing_errors_doc)] +#[allow(clippy::too_many_lines)] pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add( "RustPSQLDriverPyBaseError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "BaseConnectionPoolError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolBuildError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolConfigurationError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; - pymod.add( - "BaseConnectionError", - py.get_type_bound::(), - )?; + pymod.add("BaseConnectionError", py.get_type::())?; pymod.add( "ConnectionExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionClosedError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "BaseTransactionError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionBeginError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionCommitError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionRollbackError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionSavepointError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionClosedError", - py.get_type_bound::(), + py.get_type::(), )?; - pymod.add("BaseCursorError", py.get_type_bound::())?; - pymod.add("CursorStartError", py.get_type_bound::())?; - pymod.add("CursorCloseError", py.get_type_bound::())?; - pymod.add("CursorFetchError", py.get_type_bound::())?; - pymod.add( - "CursorClosedError", - py.get_type_bound::(), - )?; + pymod.add("BaseCursorError", py.get_type::())?; + pymod.add("CursorStartError", py.get_type::())?; + pymod.add("CursorCloseError", py.get_type::())?; + pymod.add("CursorFetchError", py.get_type::())?; + pymod.add("CursorClosedError", py.get_type::())?; pymod.add( "RustToPyValueMappingError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "PyToRustValueMappingError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "UUIDValueConvertError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "MacAddrConversionError", - py.get_type_bound::(), + py.get_type::(), + )?; + pymod.add("BaseListenerError", py.get_type::())?; + pymod.add("ListenerStartError", py.get_type::())?; + pymod.add("ListenerClosedError", py.get_type::())?; + pymod.add( + "ListenerCallbackError", + py.get_type::(), )?; Ok(()) } diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 72fb659c..48af50cb 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -5,13 +5,13 @@ use tokio::task::JoinError; use crate::exceptions::python_errors::{PyToRustValueMappingError, RustToPyValueMappingError}; use super::python_errors::{ - BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseTransactionError, - ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, + BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseListenerError, + BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, ConnectionPoolConfigurationError, ConnectionPoolExecuteError, CursorCloseError, - CursorClosedError, CursorFetchError, CursorStartError, DriverError, MacAddrParseError, - RuntimeJoinError, SSLError, TransactionBeginError, TransactionClosedError, - TransactionCommitError, TransactionExecuteError, TransactionRollbackError, - TransactionSavepointError, UUIDValueConvertError, + CursorClosedError, CursorFetchError, CursorStartError, DriverError, ListenerCallbackError, + ListenerClosedError, ListenerStartError, MacAddrParseError, RuntimeJoinError, SSLError, + TransactionBeginError, TransactionClosedError, TransactionCommitError, TransactionExecuteError, + TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError, }; pub type RustPSQLDriverPyResult = Result; @@ -64,6 +64,16 @@ pub enum RustPSQLDriverError { #[error("Underlying connection is returned to the pool")] CursorClosedError, + // Listener Errors + #[error("Listener error: {0}")] + ListenerError(String), + #[error("Listener start error: {0}")] + ListenerStartError(String), + #[error("Underlying connection is returned to the pool")] + ListenerClosedError, + #[error("Callback must be an async callable")] + ListenerCallbackError, + #[error("Can't convert value from driver to python type: {0}")] RustToPyValueConversionError(String), #[error("Can't convert value from python to rust type: {0}")] @@ -161,6 +171,14 @@ impl From for pyo3::PyErr { RustPSQLDriverError::CursorFetchError(_) => CursorFetchError::new_err((error_desc,)), RustPSQLDriverError::SSLError(_) => SSLError::new_err((error_desc,)), RustPSQLDriverError::CursorClosedError => CursorClosedError::new_err((error_desc,)), + RustPSQLDriverError::ListenerError(_) => BaseListenerError::new_err((error_desc,)), + RustPSQLDriverError::ListenerStartError(_) => { + ListenerStartError::new_err((error_desc,)) + } + RustPSQLDriverError::ListenerClosedError => ListenerClosedError::new_err((error_desc,)), + RustPSQLDriverError::ListenerCallbackError => { + ListenerCallbackError::new_err((error_desc,)) + } } } } diff --git a/src/lib.rs b/src/lib.rs index edda3119..e3602311 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,8 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; diff --git a/src/query_result.rs b/src/query_result.rs index 06299b86..162be3b5 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -16,7 +16,7 @@ fn row_to_dict<'a>( postgres_row: &'a Row, custom_decoders: &Option>, ) -> RustPSQLDriverPyResult> { - let python_dict = PyDict::new_bound(py); + let python_dict = PyDict::new(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; python_dict.set_item(column.name().to_object(py), python_type)?; @@ -85,7 +85,7 @@ impl PSQLDriverPyQueryResult { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?; - let convert_class_inst = as_class.call_bound(py, (), Some(&pydict))?; + let convert_class_inst = as_class.call(py, (), Some(&pydict))?; res.push(convert_class_inst); } @@ -109,7 +109,7 @@ impl PSQLDriverPyQueryResult { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &custom_decoders)?; - let row_factory_class = row_factory.call_bound(py, (pydict,), None)?; + let row_factory_class = row_factory.call(py, (pydict,), None)?; res.push(row_factory_class); } Ok(res.to_object(py)) @@ -170,7 +170,7 @@ impl PSQLDriverSinglePyQueryResult { as_class: Py, ) -> RustPSQLDriverPyResult> { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?; - Ok(as_class.call_bound(py, (), Some(&pydict))?) + Ok(as_class.call(py, (), Some(&pydict))?) } /// Convert result from database with function passed from Python. @@ -188,6 +188,6 @@ impl PSQLDriverSinglePyQueryResult { custom_decoders: Option>, ) -> RustPSQLDriverPyResult> { let pydict = row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py); - Ok(row_factory.call_bound(py, (pydict,), None)?) + Ok(row_factory.call(py, (pydict,), None)?) } } diff --git a/src/runtime.rs b/src/runtime.rs index 21365d4f..05889d99 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,5 +1,5 @@ use futures_util::Future; -use pyo3::{IntoPy, Py, PyAny, PyObject, Python}; +use pyo3::{IntoPyObject, Py, PyAny, Python}; use crate::exceptions::rust_errors::RustPSQLDriverPyResult; @@ -21,7 +21,7 @@ pub fn tokio_runtime() -> &'static tokio::runtime::Runtime { pub fn rustdriver_future(py: Python<'_>, future: F) -> RustPSQLDriverPyResult> where F: Future> + Send + 'static, - T: IntoPy, + T: for<'py> IntoPyObject<'py>, { let res = pyo3_async_runtimes::tokio::future_into_py(py, async { future.await.map_err(Into::into) })