diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 077354e8..f7260aed 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -43,10 +43,11 @@ jobs: apt update -y && apt-get install -y libssl-dev openssl pkg-config fi - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-linux path: dist + overwrite: true - name: Releasing assets uses: softprops/action-gh-release@v1 with: @@ -73,10 +74,11 @@ jobs: args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 sccache: 'true' - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-windows path: dist + overwrite: true - name: Releasing assets uses: softprops/action-gh-release@v1 with: @@ -113,10 +115,11 @@ jobs: args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 sccache: 'true' - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-macos path: dist + overwrite: true - name: Releasing assets uses: softprops/action-gh-release@v1 with: @@ -135,10 +138,11 @@ jobs: command: sdist args: --out dist - name: Upload sdist - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-sdist path: dist + overwrite: true - name: Releasing assets uses: softprops/action-gh-release@v1 with: @@ -167,10 +171,11 @@ jobs: args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 manylinux: musllinux_1_2 - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheels + name: wheels-musllinux path: dist + overwrite: true - name: Releasing assets uses: softprops/action-gh-release@v1 with: @@ -184,9 +189,10 @@ jobs: runs-on: ubuntu-latest needs: [linux, windows, macos, musllinux, sdist] steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: - name: wheels + pattern: wheels-* + merge-multiple: true - name: Publish to PyPI uses: PyO3/maturin-action@v1 env: diff --git a/Cargo.lock b/Cargo.lock index abf26770..8ef797d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -997,7 +997,7 @@ dependencies = [ [[package]] name = "psqlpy" -version = "0.9.0" +version = "0.9.1" dependencies = [ "byteorder", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 5e7743de..b30cc53c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "psqlpy" -version = "0.9.0" +version = "0.9.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/README.md b/README.md index 27ccecc8..1272db9c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/psqlpy?style=for-the-badge)](https://pypi.org/project/psqlpy/) +[![PyPI - Python Version](https://img.shields.io/badge/PYTHON-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue?style=for-the-badge +)](https://pypi.org/project/psqlpy/) [![PyPI](https://img.shields.io/pypi/v/psqlpy?style=for-the-badge)](https://pypi.org/project/psqlpy/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/psqlpy?style=for-the-badge)](https://pypistats.org/packages/psqlpy) @@ -54,9 +55,10 @@ async def main() -> None: max_db_pool_size=2, ) - res: QueryResult = await db_pool.execute( - "SELECT * FROM users", - ) + async with db_pool.acquire() as conn: + res: QueryResult = await conn.execute( + "SELECT * FROM users", + ) print(res.result()) db_pool.close() diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts index 133833f7..94b082f9 100644 --- a/docs/.vuepress/sidebar.ts +++ b/docs/.vuepress/sidebar.ts @@ -67,6 +67,15 @@ export default sidebar({ }, ], }, + { + text: "Integrations", + prefix: "/integrations", + collapsible: true, + children: [ + "taskiq", + "opentelemetry", + ], + }, { text: "Contribution guide", prefix: "/contribution_guide", diff --git a/docs/components/connection_pool.md b/docs/components/connection_pool.md index bcee4893..514f899d 100644 --- a/docs/components/connection_pool.md +++ b/docs/components/connection_pool.md @@ -178,55 +178,6 @@ It has 4 parameters: - `available` - available connection in the connection pool. - `waiting` - waiting requests to retrieve connection from connection pool. -### Execute - -#### Parameters: - -- `querystring`: Statement string. -- `parameters`: List of parameters for the statement string. -- `prepared`: Prepare statement before execution or not. - -You can execute any query directly from Connection Pool. -This method supports parameters, each parameter must be marked as `$` (number starts with 1). -Parameters must be passed as list after querystring. -::: caution -You must use `ConnectionPool.execute` method in high-load production code wisely! -It pulls connection from the pool each time you execute query. -Preferable way to execute statements with [Connection](./../components/connection.md) or [Transaction](./../components/transaction.md) -::: - -```python -async def main() -> None: - ... - results: QueryResult = await db_pool.execute( - "SELECT * FROM users WHERE id = $1 and username = $2", - [100, "Alex"], - ) - - dict_results: list[dict[str, Any]] = results.result() -``` - -### Fetch - -#### Parameters: - -- `querystring`: Statement string. -- `parameters`: List of parameters for the statement string. -- `prepared`: Prepare statement before execution or not. - -The same as the `execute` method, for some people this naming is preferable. - -```python -async def main() -> None: - ... - results: QueryResult = await db_pool.fetch( - "SELECT * FROM users WHERE id = $1 and username = $2", - [100, "Alex"], - ) - - dict_results: list[dict[str, Any]] = results.result() -``` - ### Acquire Get single connection for async context manager. diff --git a/docs/components/listener.md b/docs/components/listener.md index f2ecff30..000067ac 100644 --- a/docs/components/listener.md +++ b/docs/components/listener.md @@ -121,7 +121,7 @@ async def main() -> None: - `channel`: name of the channel to listen. - `callback`: coroutine callback. -Add new callback to the channel, can be called more than 1 times. +Add new callback to the channel, can be called multiple times (before or after `listen`). Callback signature is like this: ```python @@ -196,9 +196,17 @@ In the background it creates task in Rust event loop. async def main() -> None: listener = db_pool.listener() await listener.startup() - await listener.listen() + listener.listen() ``` ### Abort Listen Abort listen. If `listen()` method was called, stop listening, else don't do anything. + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.startup() + listener.listen() + listener.abort_listen() +``` diff --git a/docs/integrations/opentelemetry.md b/docs/integrations/opentelemetry.md new file mode 100644 index 00000000..a8461034 --- /dev/null +++ b/docs/integrations/opentelemetry.md @@ -0,0 +1,8 @@ +--- +title: Integration with OpenTelemetry +--- + +# OTLP-PSQLPy + +There is a library for OpenTelemetry support. +Please follow the [link](https://github.com/psqlpy-python/otlp-psqlpy) diff --git a/docs/integrations/taskiq.md b/docs/integrations/taskiq.md new file mode 100644 index 00000000..97579347 --- /dev/null +++ b/docs/integrations/taskiq.md @@ -0,0 +1,8 @@ +--- +title: Integration with TaskIQ +--- + +# TaskIQ-PSQLPy + +There is integration with [TaskIQ](https://github.com/taskiq-python/taskiq-psqlpy). +You can use PSQLPy for result backend. diff --git a/docs/introduction/lets_start.md b/docs/introduction/lets_start.md index db21c627..0ab5f4bc 100644 --- a/docs/introduction/lets_start.md +++ b/docs/introduction/lets_start.md @@ -40,7 +40,7 @@ Let's assume that we have table `users`: ```python import asyncio -from typing import Final +from typing import Final, Any from psqlpy import ConnectionPool, QueryResult @@ -49,20 +49,16 @@ async def main() -> None: # It uses default connection parameters db_pool: Final = ConnectionPool() - results: Final[QueryResult] = await db_pool.execute( - "SELECT * FROM users WHERE id = $1", - [2], - ) + async with db_pool.acquire() as conn: + results: Final[QueryResult] = await conn.execute( + "SELECT * FROM users WHERE id = $1", + [2], + ) dict_results: Final[list[dict[Any, Any]]] = results.result() - db.close() + db_pool.close() ``` ::: tip -You must call `close()` on database pool when you application is shutting down. -::: -::: caution -You must not use `ConnectionPool.execute` method in high-load production code! -It pulls new connection from connection pull each call. -Recommended way to make queries is executing them with `Connection`, `Transaction` or `Cursor`. +It's better to call `close()` on database pool when you application is shutting down. ::: diff --git a/psqlpy-stress/psqlpy_stress/mocker.py b/psqlpy-stress/psqlpy_stress/mocker.py index 156f10df..de55fc59 100644 --- a/psqlpy-stress/psqlpy_stress/mocker.py +++ b/psqlpy-stress/psqlpy_stress/mocker.py @@ -17,8 +17,9 @@ def get_pool() -> psqlpy.ConnectionPool: async def fill_users() -> None: pool = get_pool() users_amount = 10000000 + connection = await pool.connection() for _ in range(users_amount): - await pool.execute( + await connection.execute( querystring="INSERT INTO users (username) VALUES($1)", parameters=[str(uuid.uuid4())], ) @@ -35,8 +36,9 @@ def generate_random_dict() -> dict[str, str]: async def fill_big_table() -> None: pool = get_pool() big_table_amount = 10000000 + connection = await pool.connection() for _ in range(big_table_amount): - await pool.execute( + await connection.execute( "INSERT INTO big_table (string_field, integer_field, json_field, array_field) VALUES($1, $2, $3, $4)", parameters=[ str(uuid.uuid4()), diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 42b836b2..77ec440d 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -288,6 +288,16 @@ class Cursor: It can be used as an asynchronous iterator. """ + cursor_name: str + querystring: str + parameters: Sequence[Any] + prepared: bool | None + conn_dbname: str | None + user: str | None + host_addrs: list[str] + hosts: list[str] + ports: list[int] + def __aiter__(self: Self) -> Self: ... async def __anext__(self: Self) -> QueryResult: ... async def __aenter__(self: Self) -> Self: ... @@ -424,6 +434,12 @@ class Transaction: `.transaction()`. """ + conn_dbname: str | None + user: str | None + host_addrs: list[str] + hosts: list[str] + ports: list[int] + async def __aenter__(self: Self) -> Self: ... async def __aexit__( self: Self, @@ -874,6 +890,12 @@ class Connection: It can be created only from connection pool. """ + conn_dbname: str | None + user: str | None + host_addrs: list[str] + hosts: list[str] + ports: list[int] + async def __aenter__(self: Self) -> Self: ... async def __aexit__( self: Self, @@ -1284,60 +1306,6 @@ class ConnectionPool: ### Parameters: - `new_max_size`: new size for the connection pool. """ - async def execute( - self: Self, - querystring: str, - parameters: Sequence[Any] | None = None, - prepared: bool = True, - ) -> QueryResult: - """Execute the query. - - Querystring can contain `$` parameters - for converting them in the driver side. - - ### Parameters: - - `querystring`: querystring to execute. - - `parameters`: list of parameters to pass in the query. - - `prepared`: should the querystring be prepared before the request. - By default any querystring will be prepared. - - ### Example: - ```python - import asyncio - - from psqlpy import PSQLPool, QueryResult - - async def main() -> None: - db_pool = PSQLPool() - query_result: QueryResult = await psqlpy.execute( - "SELECT username FROM users WHERE id = $1", - [100], - ) - dict_result: List[Dict[Any, Any]] = query_result.result() - # you don't need to close the pool, - # it will be dropped on Rust side. - ``` - """ - async def fetch( - self: Self, - querystring: str, - parameters: Sequence[Any] | None = None, - prepared: bool = True, - ) -> QueryResult: - """Fetch the result from database. - - It's the same as `execute` method, we made it because people are used - to `fetch` method name. - - Querystring can contain `$` parameters - for converting them in the driver side. - - ### Parameters: - - `querystring`: querystring to execute. - - `parameters`: list of parameters to pass in the query. - - `prepared`: should the querystring be prepared before the request. - By default any querystring will be prepared. - """ async def connection(self: Self) -> Connection: """Create new connection. diff --git a/python/tests/conftest.py b/python/tests/conftest.py index bfa3f650..4a388f62 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -126,18 +126,19 @@ async def create_default_data_for_tests( table_name: str, number_database_records: int, ) -> AsyncGenerator[None, None]: - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute( f"CREATE TABLE {table_name} (id SERIAL, name VARCHAR(255))", ) for table_id in range(1, number_database_records + 1): new_name = random_string() - await psql_pool.execute( + await connection.execute( querystring=f"INSERT INTO {table_name} VALUES ($1, $2)", parameters=[table_id, new_name], ) yield - await psql_pool.execute( + await connection.execute( f"DROP TABLE {table_name}", ) @@ -147,14 +148,15 @@ async def create_table_for_listener_tests( psql_pool: ConnectionPool, listener_table_name: str, ) -> AsyncGenerator[None, None]: - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute( f"CREATE TABLE {listener_table_name}" f"(id SERIAL, payload VARCHAR(255)," f"channel VARCHAR(255), process_id INT)", ) yield - await psql_pool.execute( + await connection.execute( f"DROP TABLE {listener_table_name}", ) diff --git a/python/tests/test_binary_copy.py b/python/tests/test_binary_copy.py index 93cc1335..3dcfe678 100644 --- a/python/tests/test_binary_copy.py +++ b/python/tests/test_binary_copy.py @@ -15,8 +15,9 @@ async def test_binary_copy_to_table_in_connection( ) -> None: """Test binary copy in connection.""" table_name: typing.Final = "cars" - await psql_pool.execute(f"DROP TABLE IF EXISTS {table_name}") - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + await connection.execute( """ CREATE TABLE IF NOT EXISTS cars ( model VARCHAR, @@ -46,17 +47,16 @@ async def test_binary_copy_to_table_in_connection( buf.write(encoder.finish()) buf.seek(0) - async with psql_pool.acquire() as connection: - inserted_rows = await connection.binary_copy_to_table( - source=buf, - table_name=table_name, - ) + inserted_rows = await connection.binary_copy_to_table( + source=buf, + table_name=table_name, + ) expected_inserted_row: typing.Final = 32 assert inserted_rows == expected_inserted_row - real_table_rows: typing.Final = await psql_pool.execute( + real_table_rows: typing.Final = await connection.execute( f"SELECT COUNT(*) AS rows_count FROM {table_name}", ) assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row @@ -67,8 +67,10 @@ async def test_binary_copy_to_table_in_transaction( ) -> None: """Test binary copy in transaction.""" table_name: typing.Final = "cars" - await psql_pool.execute(f"DROP TABLE IF EXISTS {table_name}") - await psql_pool.execute( + + connection = await psql_pool.connection() + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + await connection.execute( """ CREATE TABLE IF NOT EXISTS cars ( model VARCHAR, @@ -108,7 +110,8 @@ async def test_binary_copy_to_table_in_transaction( assert inserted_rows == expected_inserted_row - real_table_rows: typing.Final = await psql_pool.execute( + connection = await psql_pool.connection() + real_table_rows: typing.Final = await connection.execute( f"SELECT COUNT(*) AS rows_count FROM {table_name}", ) assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 3c15991a..898cc405 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -180,8 +180,9 @@ async def test_closed_connection_error( async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: """Test `execute_batch` method.""" - await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch") - await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2") + connection = await psql_pool.connection() + await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch") + await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2") query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" async with psql_pool.acquire() as conn: await conn.execute_batch(querystring=query) diff --git a/python/tests/test_connection_pool.py b/python/tests/test_connection_pool.py index cdf2fa48..405fceb7 100644 --- a/python/tests/test_connection_pool.py +++ b/python/tests/test_connection_pool.py @@ -4,7 +4,6 @@ ConnectionPool, ConnRecyclingMethod, LoadBalanceHosts, - QueryResult, TargetSessionAttrs, connect, ) @@ -22,7 +21,8 @@ async def test_connect_func() -> None: dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", ) - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") async def test_pool_dsn_startup() -> None: @@ -31,41 +31,8 @@ async def test_pool_dsn_startup() -> None: dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", ) - await pg_pool.execute("SELECT 1") - - -async def test_pool_execute( - psql_pool: ConnectionPool, - table_name: str, - number_database_records: int, -) -> None: - """Test that ConnectionPool can execute queries.""" - select_result = await psql_pool.execute( - f"SELECT * FROM {table_name}", - ) - - assert type(select_result) == QueryResult - - inner_result = select_result.result() - assert isinstance(inner_result, list) - assert len(inner_result) == number_database_records - - -async def test_pool_fetch( - psql_pool: ConnectionPool, - table_name: str, - number_database_records: int, -) -> None: - """Test that ConnectionPool can fetch queries.""" - select_result = await psql_pool.fetch( - f"SELECT * FROM {table_name}", - ) - - assert type(select_result) == QueryResult - - inner_result = select_result.result() - assert isinstance(inner_result, list) - assert len(inner_result) == number_database_records + conn = await pg_pool.connection() + await conn.execute("SELECT 1") async def test_pool_connection( @@ -92,7 +59,8 @@ async def test_pool_conn_recycling_method( conn_recycling_method=conn_recycling_method, ) - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") async def test_build_pool_failure() -> None: @@ -139,9 +107,10 @@ async def test_pool_target_session_attrs( if target_session_attrs == TargetSessionAttrs.ReadOnly: with pytest.raises(expected_exception=RustPSQLDriverPyBaseError): - await pg_pool.execute("SELECT 1") + await pg_pool.connection() else: - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") @pytest.mark.parametrize( @@ -159,7 +128,8 @@ async def test_pool_load_balance_hosts( load_balance_hosts=load_balance_hosts, ) - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") async def test_close_connection_pool() -> None: @@ -168,12 +138,13 @@ async def test_close_connection_pool() -> None: dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", ) - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") pg_pool.close() with pytest.raises(expected_exception=RustPSQLDriverPyBaseError): - await pg_pool.execute("SELECT 1") + await pg_pool.connection() async def test_connection_pool_as_context_manager() -> None: @@ -181,8 +152,9 @@ async def test_connection_pool_as_context_manager() -> None: with ConnectionPool( dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test", ) as pg_pool: - res = await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + res = await conn.execute("SELECT 1") assert res.result() with pytest.raises(expected_exception=RustPSQLDriverPyBaseError): - await pg_pool.execute("SELECT 1") + await pg_pool.connection() diff --git a/python/tests/test_connection_pool_builder.py b/python/tests/test_connection_pool_builder.py index f937bec3..7c9de409 100644 --- a/python/tests/test_connection_pool_builder.py +++ b/python/tests/test_connection_pool_builder.py @@ -48,8 +48,8 @@ async def test_connection_pool_builder( ) pool = builder.build() - - results = await pool.execute( + connection = await pool.connection() + results = await connection.execute( querystring=f"SELECT * FROM {table_name}", ) diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py index 46722ca1..c48c8974 100644 --- a/python/tests/test_listener.py +++ b/python/tests/test_listener.py @@ -62,7 +62,9 @@ async def notify( if with_delay: await asyncio.sleep(0.5) - await psql_pool.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") + connection = await psql_pool.connection() + await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") + connection.back_to_pool() async def check_insert_callback( @@ -71,8 +73,9 @@ async def check_insert_callback( is_insert_exist: bool = True, number_of_data: int = 1, ) -> None: + connection = await psql_pool.connection() test_data_seq = ( - await psql_pool.execute( + await connection.execute( f"SELECT * FROM {listener_table_name}", ) ).result() @@ -88,14 +91,18 @@ async def check_insert_callback( assert data_record["payload"] == TEST_PAYLOAD assert data_record["channel"] == TEST_CHANNEL + connection.back_to_pool() + async def clear_test_table( psql_pool: ConnectionPool, listener_table_name: str, ) -> None: - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute( f"DELETE FROM {listener_table_name}", ) + connection.back_to_pool() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -244,7 +251,8 @@ async def test_listener_more_than_one_callback( number_of_data=2, ) - query_result = await psql_pool.execute( + connection = await psql_pool.connection() + query_result = await connection.execute( querystring=(f"SELECT * FROM {listener_table_name} WHERE channel = $1"), parameters=(additional_channel,), ) diff --git a/python/tests/test_row_factories.py b/python/tests/test_row_factories.py index 75d03e5a..9b7f3121 100644 --- a/python/tests/test_row_factories.py +++ b/python/tests/test_row_factories.py @@ -13,7 +13,8 @@ async def test_tuple_row( table_name: str, number_database_records: int, ) -> None: - conn_result = await psql_pool.execute( + connection = await psql_pool.connection() + conn_result = await connection.execute( querystring=f"SELECT * FROM {table_name}", ) tuple_res = conn_result.row_factory(row_factory=tuple_row) @@ -32,7 +33,8 @@ class ValidationTestModel: id: int name: str - conn_result = await psql_pool.execute( + connection = await psql_pool.connection() + conn_result = await connection.execute( querystring=f"SELECT * FROM {table_name}", ) class_res = conn_result.row_factory(row_factory=class_row(ValidationTestModel)) @@ -58,7 +60,8 @@ def to_class_inner(row: Dict[str, Any]) -> ValidationTestModel: return to_class_inner - conn_result = await psql_pool.execute( + connection = await psql_pool.connection() + conn_result = await connection.execute( querystring=f"SELECT * FROM {table_name}", ) class_res = conn_result.row_factory(row_factory=to_class(ValidationTestModel)) diff --git a/python/tests/test_ssl_mode.py b/python/tests/test_ssl_mode.py index 53978d9e..4c72014e 100644 --- a/python/tests/test_ssl_mode.py +++ b/python/tests/test_ssl_mode.py @@ -35,7 +35,8 @@ async def test_ssl_mode_require( ca_file=ssl_cert_file, ) - await pg_pool.execute("SELECT 1") + conn = await pg_pool.connection() + await conn.execute("SELECT 1") @pytest.mark.parametrize( @@ -72,7 +73,8 @@ async def test_ssl_mode_require_pool_builder( pool = builder.build() - await pool.execute("SELECT 1") + connection = await pool.connection() + await connection.execute("SELECT 1") async def test_ssl_mode_require_without_ca_file( @@ -94,4 +96,5 @@ async def test_ssl_mode_require_without_ca_file( ) pool = builder.build() - await pool.execute("SELECT 1") + connection = await pool.connection() + await connection.execute("SELECT 1") diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index 7704393b..151f5bb5 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -39,8 +39,7 @@ async def test_transaction_init_parameters( deferrable: bool | None, read_variant: ReadVariant | None, ) -> None: - connection = await psql_pool.connection() - async with connection.transaction( + async with psql_pool.acquire() as connection, connection.transaction( isolation_level=isolation_level, deferrable=deferrable, read_variant=read_variant, @@ -79,6 +78,8 @@ async def test_transaction_begin( assert len(result.result()) == number_database_records + await transaction.commit() + async def test_transaction_commit( psql_pool: ConnectionPool, @@ -97,7 +98,8 @@ async def test_transaction_commit( # Make request from other connection, it mustn't know # about new INSERT data before commit. - result = await psql_pool.execute( + connection = await psql_pool.connection() + result = await connection.execute( f"SELECT * FROM {table_name} WHERE name = $1", parameters=[test_name], ) @@ -105,7 +107,7 @@ async def test_transaction_commit( await transaction.commit() - result = await psql_pool.execute( + result = await connection.execute( f"SELECT * FROM {table_name} WHERE name = $1", parameters=[test_name], ) @@ -136,7 +138,8 @@ async def test_transaction_savepoint( assert result.result() await transaction.rollback_savepoint(savepoint_name=savepoint_name) - result = await psql_pool.execute( + connection = await psql_pool.connection() + result = await connection.execute( f"SELECT * FROM {table_name} WHERE name = $1", parameters=[test_name], ) @@ -174,10 +177,12 @@ async def test_transaction_rollback( parameters=[test_name], ) - result_from_conn = await psql_pool.execute( + connection = await psql_pool.connection() + result_from_conn = await connection.execute( f"INSERT INTO {table_name} VALUES ($1, $2)", parameters=[100, test_name], ) + connection.back_to_pool() assert not (result_from_conn.result()) @@ -344,14 +349,17 @@ async def test_transaction_send_underlying_connection_to_pool_manually( async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: """Test `execute_batch` method.""" - await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch") - await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2") + connection = await psql_pool.connection() + await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch") + await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2") query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" - async with psql_pool.acquire() as conn, conn.transaction() as transaction: + async with connection.transaction() as transaction: await transaction.execute_batch(querystring=query) await transaction.execute(querystring="SELECT * FROM execute_batch") await transaction.execute(querystring="SELECT * FROM execute_batch2") + connection.back_to_pool() + @pytest.mark.parametrize( "synchronous_commit", diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 414b051f..de62c554 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -103,7 +103,8 @@ async def test_as_class( number_database_records: int, ) -> None: """Test `as_class()` method.""" - select_result = await psql_pool.execute( + connection = await psql_pool.connection() + select_result = await connection.execute( f"SELECT * FROM {table_name}", ) @@ -649,20 +650,21 @@ async def test_deserialization_simple_into_python( expected_deserialized: Any, ) -> None: """Test how types can cast from Python and to Python.""" - await psql_pool.execute("DROP TABLE IF EXISTS for_test") + connection = await psql_pool.connection() + await connection.execute("DROP TABLE IF EXISTS for_test") create_table_query = f""" CREATE TABLE for_test (test_field {postgres_type}) """ insert_data_query = """ INSERT INTO for_test VALUES ($1) """ - await psql_pool.execute(querystring=create_table_query) - await psql_pool.execute( + await connection.execute(querystring=create_table_query) + await connection.execute( querystring=insert_data_query, parameters=[py_value], ) - raw_result = await psql_pool.execute( + raw_result = await connection.execute( querystring="SELECT test_field FROM for_test", ) @@ -673,12 +675,13 @@ async def test_deserialization_composite_into_python( psql_pool: ConnectionPool, ) -> None: """Test that it's possible to deserialize custom postgresql type.""" - await psql_pool.execute("DROP TABLE IF EXISTS for_test") - await psql_pool.execute("DROP TYPE IF EXISTS all_types") - await psql_pool.execute("DROP TYPE IF EXISTS inner_type") - await psql_pool.execute("DROP TYPE IF EXISTS enum_type") - await psql_pool.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')") - await psql_pool.execute("CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)") + connection = await psql_pool.connection() + await connection.execute("DROP TABLE IF EXISTS for_test") + await connection.execute("DROP TYPE IF EXISTS all_types") + await connection.execute("DROP TYPE IF EXISTS inner_type") + await connection.execute("DROP TYPE IF EXISTS enum_type") + await connection.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')") + await connection.execute("CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)") create_type_query = """ CREATE type all_types AS ( bytea_ BYTEA, @@ -736,10 +739,10 @@ async def test_deserialization_composite_into_python( CREATE table for_test (custom_type all_types) """ - await psql_pool.execute( + await connection.execute( querystring=create_type_query, ) - await psql_pool.execute( + await connection.execute( querystring=create_table_query, ) @@ -752,7 +755,7 @@ class TestEnum(Enum): row_values += ", ROW($41, $42), " row_values += ", ".join([f"${index}" for index in range(43, 50)]) - await psql_pool.execute( + await connection.execute( querystring=f"INSERT INTO for_test VALUES (ROW({row_values}))", parameters=[ b"Bytes", @@ -914,7 +917,7 @@ class ValidateModelForCustomType(BaseModel): class TopLevelModel(BaseModel): custom_type: ValidateModelForCustomType - query_result = await psql_pool.execute( + query_result = await connection.execute( "SELECT custom_type FROM for_test", ) @@ -938,21 +941,22 @@ class TestStrEnum(str, Enum): SAD = "sad" HAPPY = "happy" - await psql_pool.execute("DROP TABLE IF EXISTS for_test") - await psql_pool.execute("DROP TYPE IF EXISTS mood") - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute("DROP TABLE IF EXISTS for_test") + await connection.execute("DROP TYPE IF EXISTS mood") + await connection.execute( "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')", ) - await psql_pool.execute( + await connection.execute( "CREATE TABLE for_test (test_mood mood, test_mood2 mood)", ) - await psql_pool.execute( + await connection.execute( querystring="INSERT INTO for_test VALUES ($1, $2)", parameters=[TestEnum.HAPPY, TestEnum.OK], ) - qs_result = await psql_pool.execute( + qs_result = await connection.execute( "SELECT * FROM for_test", ) assert qs_result.result()[0]["test_mood"] == TestEnum.HAPPY.value @@ -964,17 +968,18 @@ async def test_custom_type_as_parameter( psql_pool: ConnectionPool, ) -> None: """Tests that we can use `PyCustomType`.""" - await psql_pool.execute("DROP TABLE IF EXISTS for_test") - await psql_pool.execute( + connection = await psql_pool.connection() + await connection.execute("DROP TABLE IF EXISTS for_test") + await connection.execute( "CREATE TABLE for_test (nickname VARCHAR)", ) - await psql_pool.execute( + await connection.execute( querystring="INSERT INTO for_test VALUES ($1)", parameters=[CustomType(b"Some Real Nickname")], ) - qs_result = await psql_pool.execute( + qs_result = await connection.execute( "SELECT * FROM for_test", ) @@ -985,28 +990,29 @@ async def test_custom_type_as_parameter( async def test_custom_decoder( psql_pool: ConnectionPool, ) -> None: - await psql_pool.execute("DROP TABLE IF EXISTS for_test") - await psql_pool.execute( - "CREATE TABLE for_test (geo_point POINT)", - ) - - await psql_pool.execute( - "INSERT INTO for_test VALUES ('(1, 1)')", - ) - def point_encoder(point_bytes: bytes) -> str: # noqa: ARG001 return "Just An Example" - qs_result = await psql_pool.execute( - "SELECT * FROM for_test", - ) - result = qs_result.result( - custom_decoders={ - "geo_point": point_encoder, - }, - ) + async with psql_pool.acquire() as conn: + await conn.execute("DROP TABLE IF EXISTS for_test") + await conn.execute( + "CREATE TABLE for_test (geo_point POINT)", + ) - assert result[0]["geo_point"] == "Just An Example" + await conn.execute( + "INSERT INTO for_test VALUES ('(1, 1)')", + ) + + qs_result = await conn.execute( + "SELECT * FROM for_test", + ) + result = qs_result.result( + custom_decoders={ + "geo_point": point_encoder, + }, + ) + + assert result[0]["geo_point"] == "Just An Example" async def test_row_factory_query_result( @@ -1014,77 +1020,80 @@ async def test_row_factory_query_result( table_name: str, number_database_records: int, ) -> None: - select_result = await psql_pool.execute( - f"SELECT * FROM {table_name}", - ) + async with psql_pool.acquire() as conn: + select_result = await conn.execute( + f"SELECT * FROM {table_name}", + ) - def row_factory(db_result: Dict[str, Any]) -> List[str]: - return list(db_result.keys()) + def row_factory(db_result: Dict[str, Any]) -> List[str]: + return list(db_result.keys()) - as_row_factory = select_result.row_factory( - row_factory=row_factory, - ) - assert len(as_row_factory) == number_database_records + as_row_factory = select_result.row_factory( + row_factory=row_factory, + ) + assert len(as_row_factory) == number_database_records - assert isinstance(as_row_factory[0], list) + assert isinstance(as_row_factory[0], list) async def test_row_factory_single_query_result( psql_pool: ConnectionPool, table_name: str, ) -> None: - connection = await psql_pool.connection() - select_result = await connection.fetch_row( - f"SELECT * FROM {table_name} LIMIT 1", - ) + async with psql_pool.acquire() as conn: + select_result = await conn.fetch_row( + f"SELECT * FROM {table_name} LIMIT 1", + ) - def row_factory(db_result: Dict[str, Any]) -> List[str]: - return list(db_result.keys()) + def row_factory(db_result: Dict[str, Any]) -> List[str]: + return list(db_result.keys()) - as_row_factory = select_result.row_factory( - row_factory=row_factory, - ) - expected_number_of_elements_in_result = 2 - assert len(as_row_factory) == expected_number_of_elements_in_result + as_row_factory = select_result.row_factory( + row_factory=row_factory, + ) + expected_number_of_elements_in_result = 2 + assert len(as_row_factory) == expected_number_of_elements_in_result - assert isinstance(as_row_factory, list) + assert isinstance(as_row_factory, list) async def test_incorrect_dimensions_array( psql_pool: ConnectionPool, ) -> None: - await psql_pool.execute("DROP TABLE IF EXISTS test_marr") - await psql_pool.execute("CREATE TABLE test_marr (var_array VARCHAR ARRAY)") - - with pytest.raises(expected_exception=PyToRustValueMappingError): - await psql_pool.execute( - querystring="INSERT INTO test_marr VALUES ($1)", - parameters=[ - [ - ["Len", "is", "Three"], - ["Len", "is", "Four", "Wow"], + async with psql_pool.acquire() as conn: + await conn.execute("DROP TABLE IF EXISTS test_marr") + await conn.execute("CREATE TABLE test_marr (var_array VARCHAR ARRAY)") + + with pytest.raises(expected_exception=PyToRustValueMappingError): + await conn.execute( + querystring="INSERT INTO test_marr VALUES ($1)", + parameters=[ + [ + ["Len", "is", "Three"], + ["Len", "is", "Four", "Wow"], + ], ], - ], - ) + ) async def test_empty_array( psql_pool: ConnectionPool, ) -> None: - await psql_pool.execute("DROP TABLE IF EXISTS test_earr") - await psql_pool.execute( - "CREATE TABLE test_earr (id serial NOT NULL PRIMARY KEY, e_array text[] NOT NULL DEFAULT array[]::text[])", - ) + async with psql_pool.acquire() as conn: + await conn.execute("DROP TABLE IF EXISTS test_earr") + await conn.execute( + "CREATE TABLE test_earr (id serial NOT NULL PRIMARY KEY, e_array text[] NOT NULL DEFAULT array[]::text[])", + ) - await psql_pool.execute("INSERT INTO test_earr(id) VALUES(2);") + await conn.execute("INSERT INTO test_earr(id) VALUES(2);") - res = await psql_pool.execute( - "SELECT * FROM test_earr WHERE id = 2", - ) + res = await conn.execute( + "SELECT * FROM test_earr WHERE id = 2", + ) - json_result = res.result() - assert json_result - assert not json_result[0]["e_array"] + json_result = res.result() + assert json_result + assert not json_result[0]["e_array"] @pytest.mark.parametrize( @@ -1557,21 +1566,22 @@ async def test_array_types( py_value: Any, expected_deserialized: Any, ) -> None: - await psql_pool.execute("DROP TABLE IF EXISTS for_test") - create_table_query = f""" - CREATE TABLE for_test (test_field {postgres_type}) - """ - insert_data_query = """ - INSERT INTO for_test VALUES ($1) - """ - await psql_pool.execute(querystring=create_table_query) - await psql_pool.execute( - querystring=insert_data_query, - parameters=[py_value], - ) + async with psql_pool.acquire() as conn: + await conn.execute("DROP TABLE IF EXISTS for_test") + create_table_query = f""" + CREATE TABLE for_test (test_field {postgres_type}) + """ + insert_data_query = """ + INSERT INTO for_test VALUES ($1) + """ + await conn.execute(querystring=create_table_query) + await conn.execute( + querystring=insert_data_query, + parameters=[py_value], + ) - raw_result = await psql_pool.execute( - querystring="SELECT test_field FROM for_test", - ) + raw_result = await conn.execute( + querystring="SELECT test_field FROM for_test", + ) - assert raw_result.result()[0]["test_field"] == expected_deserialized + assert raw_result.result()[0]["test_field"] == expected_deserialized diff --git a/src/common.rs b/src/common.rs index 8dc70fc3..d0ec15e4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,13 +1,6 @@ use pyo3::{ types::{PyAnyMethods, PyModule, PyModuleMethods}, - Bound, PyAny, PyResult, Python, -}; - -use crate::{ - driver::connection::PsqlpyConnection, - exceptions::rust_errors::RustPSQLDriverPyResult, - query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{convert_parameters, PythonDTO, QueryParameter}, + Bound, PyResult, Python, }; /// Add new module to the parent one. @@ -33,104 +26,3 @@ pub fn add_module( )?; Ok(()) } - -pub trait ObjectQueryTrait { - fn psqlpy_query_one( - &self, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> impl std::future::Future> + Send; - - fn psqlpy_query( - &self, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> impl std::future::Future> + Send; - - fn psqlpy_query_simple( - &self, - querystring: String, - ) -> impl std::future::Future> + Send; -} - -impl ObjectQueryTrait for PsqlpyConnection { - async fn psqlpy_query_one( - &self, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> RustPSQLDriverPyResult { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - self.query_one( - &self.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - } else { - self.query_one( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - }; - - Ok(PSQLDriverSinglePyQueryResult::new(result)) - } - - async fn psqlpy_query( - &self, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> RustPSQLDriverPyResult { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - self.query( - &self.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - } else { - self.query( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - }; - - Ok(PSQLDriverPyQueryResult::new(result)) - } - - async fn psqlpy_query_simple(&self, querystring: String) -> RustPSQLDriverPyResult<()> { - self.batch_execute(querystring.as_str()).await - } -} diff --git a/src/driver/connection.rs b/src/driver/connection.rs index f10328c2..3c0595bb 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,125 +1,44 @@ -use bytes::{Buf, BytesMut}; -use deadpool_postgres::{Object, Pool}; +use bytes::BytesMut; +use deadpool_postgres::Pool; use futures_util::pin_mut; -use postgres_types::ToSql; use pyo3::{buffer::PyBuffer, pyclass, pymethods, Py, PyAny, PyErr, Python}; -use std::{collections::HashSet, sync::Arc, vec}; -use tokio_postgres::{ - binary_copy::BinaryCopyInWriter, Client, CopyInSink, Row, Statement, ToStatement, -}; +use std::{collections::HashSet, net::IpAddr, sync::Arc}; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, runtime::tokio_runtime, - value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, }; use super::{ cursor::Cursor, + inner_connection::PsqlpyConnection, transaction::Transaction, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; -#[allow(clippy::module_name_repetitions)] -pub enum PsqlpyConnection { - PoolConn(Object), - SingleConn(Client), -} - -impl PsqlpyConnection { - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot prepare statement. - pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), - } - } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute statement. - pub async fn query( - &self, - statement: &T, - params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult> - where - T: ?Sized + ToStatement, - { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), - PsqlpyConnection::SingleConn(sconn) => { - return Ok(sconn.query(statement, params).await?) - } - } - } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute statement. - pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), - } - } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute statement. - pub async fn query_one( - &self, - statement: &T, - params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult - where - T: ?Sized + ToStatement, - { - match self { - PsqlpyConnection::PoolConn(pconn) => { - return Ok(pconn.query_one(statement, params).await?) - } - PsqlpyConnection::SingleConn(sconn) => { - return Ok(sconn.query_one(statement, params).await?) - } - } - } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute copy data. - pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> - where - T: ?Sized + ToStatement, - U: Buf + 'static + Send, - { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), - } - } -} - #[pyclass(subclass)] #[derive(Clone)] pub struct Connection { db_client: Option>, db_pool: Option, + pg_config: Arc, } impl Connection { #[must_use] - pub fn new(db_client: Option>, db_pool: Option) -> Self { - Connection { db_client, db_pool } + pub fn new( + db_client: Option>, + db_pool: Option, + pg_config: Arc, + ) -> Self { + Connection { + db_client, + db_pool, + pg_config, + } } #[must_use] @@ -135,12 +54,89 @@ impl Connection { impl Default for Connection { fn default() -> Self { - Connection::new(None, None) + Connection::new(None, None, Arc::new(Config::default())) } } #[pymethods] impl Connection { + #[getter] + fn conn_dbname(&self) -> Option<&str> { + self.pg_config.get_dbname() + } + + #[getter] + fn user(&self) -> Option<&str> { + self.pg_config.get_user() + } + + #[getter] + fn host_addrs(&self) -> Vec { + let mut host_addrs_vec = vec![]; + + let host_addrs = self.pg_config.get_hostaddrs(); + for ip_addr in host_addrs { + match ip_addr { + IpAddr::V4(ipv4) => { + host_addrs_vec.push(ipv4.to_string()); + } + IpAddr::V6(ipv6) => { + host_addrs_vec.push(ipv6.to_string()); + } + } + } + + host_addrs_vec + } + + #[cfg(unix)] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + Host::Unix(host) => { + hosts_vec.push(host.display().to_string()); + } + } + } + + hosts_vec + } + + #[cfg(not(unix))] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + _ => unreachable!(), + } + } + + hosts_vec + } + + #[getter] + fn ports(&self) -> Vec<&u16> { + return self.pg_config.get_ports().iter().collect::>(); + } + + #[getter] + fn options(&self) -> Option<&str> { + return self.pg_config.get_options(); + } + async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -213,54 +209,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - db_client - .query( - &db_client - .prepare_cached(&querystring) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - } else { - db_client - .query( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - }; - - return Ok(PSQLDriverPyQueryResult::new(result)); + return db_client.execute(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -311,60 +260,9 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let mut params: Vec> = vec![]; - if let Some(parameters) = parameters { - for vec_of_py_any in parameters { - params.push(convert_parameters(vec_of_py_any)?); - } - } - let prepared = prepared.unwrap_or(true); - - db_client.batch_execute("BEGIN;").await.map_err(|err| { - RustPSQLDriverError::TransactionBeginError(format!( - "Cannot start transaction to run execute_many: {err}" - )) - })?; - for param in params { - let querystring_result = if prepared { - let prepared_stmt = &db_client.prepare_cached(&querystring).await; - if let Err(error) = prepared_stmt { - return Err(RustPSQLDriverError::TransactionExecuteError(format!( - "Cannot prepare statement in execute_many, operation rolled back {error}", - ))); - } - db_client - .query( - &db_client.prepare_cached(&querystring).await?, - ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - } else { - db_client - .query( - &querystring, - ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - }; - - if let Err(error) = querystring_result { - db_client.batch_execute("ROLLBACK;").await?; - return Err(RustPSQLDriverError::TransactionExecuteError(format!( - "Error occured in `execute_many` statement, transaction is rolled back: {error}" - ))); - } - } - db_client.batch_execute("COMMIT;").await?; - - return Ok(()); + return db_client + .execute_many(querystring, parameters, prepared) + .await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -388,54 +286,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - db_client - .query( - &db_client - .prepare_cached(&querystring) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - } else { - db_client - .query( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - }; - - return Ok(PSQLDriverPyQueryResult::new(result)); + return db_client.execute(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -465,54 +316,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - db_client - .query_one( - &db_client - .prepare_cached(&querystring) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - } else { - db_client - .query_one( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - }; - - return Ok(PSQLDriverSinglePyQueryResult::new(result)); + return db_client.fetch_row(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -539,57 +343,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - - let result = if prepared { - db_client - .query_one( - &db_client - .prepare_cached(&querystring) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - } else { - db_client - .query_one( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - }; - - return Python::with_gil(|gil| match result.columns().first() { - Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None), - None => Ok(gil.None()), - }); + return db_client.fetch_val(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -615,6 +369,7 @@ impl Connection { if let Some(db_client) = &self.db_client { return Ok(Transaction::new( db_client.clone(), + self.pg_config.clone(), false, false, isolation_level, @@ -650,6 +405,7 @@ impl Connection { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), + self.pg_config.clone(), querystring, parameters, "cur_name".into(), diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 4f38407d..24780a6a 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,18 +1,15 @@ use crate::runtime::tokio_runtime; -use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; +use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; -use std::{sync::Arc, vec}; +use std::sync::Arc; use tokio_postgres::Config; -use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - query_result::PSQLDriverPyQueryResult, - value_converter::{convert_parameters, PythonDTO, QueryParameter}, -}; +use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; use super::{ common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, - connection::{Connection, PsqlpyConnection}, + connection::Connection, + inner_connection::PsqlpyConnection, listener::core::Listener, utils::{build_connection_config, build_manager, build_tls}, }; @@ -138,10 +135,10 @@ pub fn connect( let pool = db_pool_builder.build()?; Ok(ConnectionPool { - pool, - pg_config, - ca_file, - ssl_mode, + pool: pool, + pg_config: Arc::new(pg_config), + ca_file: ca_file, + ssl_mode: ssl_mode, }) } @@ -207,7 +204,7 @@ impl ConnectionPoolStatus { #[pyclass(subclass)] pub struct ConnectionPool { pool: Pool, - pg_config: Config, + pg_config: Arc, ca_file: Option, ssl_mode: Option, } @@ -221,10 +218,10 @@ impl ConnectionPool { ssl_mode: Option, ) -> Self { ConnectionPool { - pool, - pg_config, - ca_file, - ssl_mode, + pool: pool, + pg_config: Arc::new(pg_config), + ca_file: ca_file, + ssl_mode: ssl_mode, } } } @@ -361,144 +358,9 @@ impl ConnectionPool { self.pool.resize(new_max_size); } - /// Execute querystring with parameters. - /// - /// Prepare statement and cache it, then execute. - /// - /// # Errors - /// May return Err Result if cannot retrieve new connection - /// or prepare statement or execute statement. - #[pyo3(signature = (querystring, parameters=None, prepared=None))] - pub async fn execute<'a>( - self_: pyo3::Py, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); - - let db_pool_manager = tokio_runtime() - .spawn(async move { Ok::(db_pool.get().await?) }) - .await??; - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - let result = if prepared { - tokio_runtime() - .spawn(async move { - db_pool_manager - .query( - &db_pool_manager.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement from ConnectionPool, error - {err}" - )) - }) - }) - .await?? - } else { - tokio_runtime() - .spawn(async move { - db_pool_manager - .query( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement from ConnectionPool, error - {err}" - )) - }) - }) - .await?? - }; - Ok(PSQLDriverPyQueryResult::new(result)) - } - - /// Fetch result from the database. - /// - /// It's the same as `execute`, we made it for people who prefer - /// `fetch()`. - /// - /// Prepare statement and cache it, then execute. - /// - /// # Errors - /// May return Err Result if cannot retrieve new connection - /// or prepare statement or execute statement. - #[pyo3(signature = (querystring, parameters=None, prepared=None))] - pub async fn fetch<'a>( - self_: pyo3::Py, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); - - let db_pool_manager = tokio_runtime() - .spawn(async move { Ok::(db_pool.get().await?) }) - .await??; - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - let prepared = prepared.unwrap_or(true); - let result = if prepared { - tokio_runtime() - .spawn(async move { - db_pool_manager - .query( - &db_pool_manager.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement from ConnectionPool, error - {err}" - )) - }) - }) - .await?? - } else { - tokio_runtime() - .spawn(async move { - db_pool_manager - .query( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>(), - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement from ConnectionPool, error - {err}" - )) - }) - }) - .await?? - }; - Ok(PSQLDriverPyQueryResult::new(result)) - } - #[must_use] pub fn acquire(&self) -> Connection { - Connection::new(None, Some(self.pool.clone())) + Connection::new(None, Some(self.pool.clone()), self.pg_config.clone()) } #[must_use] @@ -521,7 +383,10 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub async fn connection(self_: pyo3::Py) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); + let (db_pool, pg_config) = pyo3::Python::with_gil(|gil| { + let slf = self_.borrow(gil); + (slf.pool.clone(), slf.pg_config.clone()) + }); let db_connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) @@ -531,6 +396,7 @@ impl ConnectionPool { Ok(Connection::new( Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))), None, + pg_config, )) } diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 7368d29a..f391d1c1 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -1,17 +1,17 @@ -use std::sync::Arc; +use std::{net::IpAddr, sync::Arc}; use pyo3::{ exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, }; +use tokio_postgres::{config::Host, Config}; use crate::{ - common::ObjectQueryTrait, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, query_result::PSQLDriverPyQueryResult, runtime::rustdriver_future, }; -use super::connection::PsqlpyConnection; +use super::inner_connection::PsqlpyConnection; /// Additional implementation for the `Object` type. #[allow(clippy::ref_option)] @@ -55,7 +55,7 @@ impl CursorObjectTrait for PsqlpyConnection { cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); - self.psqlpy_query(cursor_init_query, parameters.clone(), *prepared) + self.execute(cursor_init_query, parameters.clone(), *prepared) .await .map_err(|err| { RustPSQLDriverError::CursorStartError(format!("Cannot start cursor, error - {err}")) @@ -77,7 +77,7 @@ impl CursorObjectTrait for PsqlpyConnection { )); } - self.psqlpy_query( + self.execute( format!("CLOSE {cursor_name}"), Option::default(), Some(false), @@ -91,6 +91,7 @@ impl CursorObjectTrait for PsqlpyConnection { #[pyclass(subclass)] pub struct Cursor { db_transaction: Option>, + pg_config: Arc, querystring: String, parameters: Option>, cursor_name: String, @@ -105,6 +106,7 @@ impl Cursor { #[must_use] pub fn new( db_transaction: Arc, + pg_config: Arc, querystring: String, parameters: Option>, cursor_name: String, @@ -114,6 +116,7 @@ impl Cursor { ) -> Self { Cursor { db_transaction: Some(db_transaction), + pg_config, querystring, parameters, cursor_name, @@ -128,6 +131,98 @@ impl Cursor { #[pymethods] impl Cursor { + #[getter] + fn conn_dbname(&self) -> Option<&str> { + self.pg_config.get_dbname() + } + + #[getter] + fn user(&self) -> Option<&str> { + self.pg_config.get_user() + } + + #[getter] + fn host_addrs(&self) -> Vec { + let mut host_addrs_vec = vec![]; + + let host_addrs = self.pg_config.get_hostaddrs(); + for ip_addr in host_addrs { + match ip_addr { + IpAddr::V4(ipv4) => { + host_addrs_vec.push(ipv4.to_string()); + } + IpAddr::V6(ipv6) => { + host_addrs_vec.push(ipv6.to_string()); + } + } + } + + host_addrs_vec + } + + #[cfg(unix)] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + Host::Unix(host) => { + hosts_vec.push(host.display().to_string()); + } + } + } + + hosts_vec + } + + #[cfg(not(unix))] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + _ => unreachable!(), + } + } + + hosts_vec + } + + #[getter] + fn ports(&self) -> Vec<&u16> { + return self.pg_config.get_ports().iter().collect::>(); + } + + #[getter] + fn cursor_name(&self) -> String { + return self.cursor_name.clone(); + } + + #[getter] + fn querystring(&self) -> String { + return self.querystring.clone(); + } + + #[getter] + fn parameters(&self) -> Option> { + return self.parameters.clone(); + } + + #[getter] + fn prepared(&self) -> Option { + return self.prepared.clone(); + } + #[must_use] fn __aiter__(slf: Py) -> Py { slf @@ -220,7 +315,7 @@ impl Cursor { rustdriver_future(gil, async move { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH {fetch_number} FROM {cursor_name}"), None, Some(false), @@ -318,7 +413,7 @@ impl Cursor { }; let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH {fetch_number} FROM {cursor_name}"), None, Some(false), @@ -350,7 +445,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query(format!("FETCH NEXT FROM {cursor_name}"), None, Some(false)) + .execute(format!("FETCH NEXT FROM {cursor_name}"), None, Some(false)) .await .map_err(|err| { RustPSQLDriverError::CursorFetchError(format!( @@ -377,7 +472,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query(format!("FETCH PRIOR FROM {cursor_name}"), None, Some(false)) + .execute(format!("FETCH PRIOR FROM {cursor_name}"), None, Some(false)) .await .map_err(|err| { RustPSQLDriverError::CursorFetchError(format!( @@ -404,7 +499,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query(format!("FETCH FIRST FROM {cursor_name}"), None, Some(false)) + .execute(format!("FETCH FIRST FROM {cursor_name}"), None, Some(false)) .await .map_err(|err| { RustPSQLDriverError::CursorFetchError(format!( @@ -431,7 +526,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query(format!("FETCH LAST FROM {cursor_name}"), None, Some(false)) + .execute(format!("FETCH LAST FROM {cursor_name}"), None, Some(false)) .await .map_err(|err| { RustPSQLDriverError::CursorFetchError(format!( @@ -461,7 +556,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH ABSOLUTE {absolute_number} FROM {cursor_name}"), None, Some(false), @@ -495,7 +590,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH RELATIVE {relative_number} FROM {cursor_name}"), None, Some(false), @@ -528,7 +623,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH FORWARD ALL FROM {cursor_name}"), None, Some(false), @@ -562,7 +657,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH BACKWARD {backward_count} FROM {cursor_name}",), None, Some(false), @@ -595,7 +690,7 @@ impl Cursor { if let Some(db_transaction) = db_transaction { let result = db_transaction - .psqlpy_query( + .execute( format!("FETCH BACKWARD ALL FROM {cursor_name}"), None, Some(false), diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs new file mode 100644 index 00000000..c66006cc --- /dev/null +++ b/src/driver/inner_connection.rs @@ -0,0 +1,267 @@ +use bytes::Buf; +use deadpool_postgres::Object; +use postgres_types::ToSql; +use pyo3::{Py, PyAny, Python}; +use std::vec; +use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement}; + +use crate::{ + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, + value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, +}; + +#[allow(clippy::module_name_repetitions)] +pub enum PsqlpyConnection { + PoolConn(Object), + SingleConn(Client), +} + +impl PsqlpyConnection { + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot prepare statement. + pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), + } + } + + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult> + where + T: ?Sized + ToStatement, + { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query(statement, params).await?) + } + } + } + + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), + } + } + + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult + where + T: ?Sized + ToStatement, + { + match self { + PsqlpyConnection::PoolConn(pconn) => { + return Ok(pconn.query_one(statement, params).await?) + } + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query_one(statement, params).await?) + } + } + } + + pub async fn execute( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult { + let prepared = prepared.unwrap_or(true); + + let mut params: Vec = vec![]; + if let Some(parameters) = parameters { + params = convert_parameters(parameters)?; + } + + let boxed_params = ¶ms + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(); + + let result = if prepared { + self.query( + &self.prepare_cached(&querystring).await.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, + boxed_params, + ) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + } else { + self.query(&querystring, boxed_params) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + }; + + Ok(PSQLDriverPyQueryResult::new(result)) + } + + pub async fn execute_many( + &self, + querystring: String, + parameters: Option>>, + prepared: Option, + ) -> RustPSQLDriverPyResult<()> { + let prepared = prepared.unwrap_or(true); + + let mut params: Vec> = vec![]; + if let Some(parameters) = parameters { + for vec_of_py_any in parameters { + params.push(convert_parameters(vec_of_py_any)?); + } + } + + for param in params { + let boxed_params = ¶m + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(); + + let querystring_result = if prepared { + let prepared_stmt = &self.prepare_cached(&querystring).await; + if let Err(error) = prepared_stmt { + return Err(RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement in execute_many, operation rolled back {error}", + ))); + } + self.query(&self.prepare_cached(&querystring).await?, boxed_params) + .await + } else { + self.query(&querystring, boxed_params).await + }; + + if let Err(error) = querystring_result { + return Err(RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occured in `execute_many` statement: {error}" + ))); + } + } + + return Ok(()); + } + + pub async fn fetch_row_raw( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult { + let prepared = prepared.unwrap_or(true); + + let mut params: Vec = vec![]; + if let Some(parameters) = parameters { + params = convert_parameters(parameters)?; + } + + let boxed_params = ¶ms + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(); + + let result = if prepared { + self.query_one( + &self.prepare_cached(&querystring).await.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, + boxed_params, + ) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + } else { + self.query_one(&querystring, boxed_params) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + }; + + return Ok(result); + } + + pub async fn fetch_row( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult { + let result = self + .fetch_row_raw(querystring, parameters, prepared) + .await?; + + return Ok(PSQLDriverSinglePyQueryResult::new(result)); + } + + pub async fn fetch_val( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult> { + let result = self + .fetch_row_raw(querystring, parameters, prepared) + .await?; + + return Python::with_gil(|gil| match result.columns().first() { + Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None), + None => Ok(gil.None()), + }); + } + + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute copy data. + pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), + } + } +} diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index c8fd271c..83aa9b3e 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -14,7 +14,8 @@ use tokio_postgres::{AsyncMessage, Config}; use crate::{ driver::{ common_options::SslMode, - connection::{Connection, PsqlpyConnection}, + connection::Connection, + inner_connection::PsqlpyConnection, utils::{build_tls, is_coroutine_function, ConfiguredTLS}, }, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -27,7 +28,7 @@ use super::structs::{ #[pyclass] pub struct Listener { - pg_config: Config, + pg_config: Arc, ca_file: Option, ssl_mode: Option, channel_callbacks: Arc>, @@ -41,14 +42,14 @@ pub struct Listener { impl Listener { #[must_use] - pub fn new(pg_config: Config, ca_file: Option, ssl_mode: Option) -> Self { + pub fn new(pg_config: Arc, ca_file: Option, ssl_mode: Option) -> Self { Listener { - pg_config, + pg_config: pg_config.clone(), ca_file, ssl_mode, channel_callbacks: Arc::default(), listen_abort_handler: Option::default(), - connection: Connection::new(None, None), + connection: Connection::new(None, None, pg_config.clone()), receiver: Option::default(), listen_query: Arc::default(), is_listened: Arc::new(RwLock::new(false)), @@ -217,8 +218,11 @@ impl Listener { tokio_runtime().spawn(connection); self.receiver = Some(Arc::new(RwLock::new(receiver))); - self.connection = - Connection::new(Some(Arc::new(PsqlpyConnection::SingleConn(client))), None); + self.connection = Connection::new( + Some(Arc::new(PsqlpyConnection::SingleConn(client))), + None, + self.pg_config.clone(), + ); self.is_started = true; diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 578bf2cd..e7827cd5 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,6 +3,7 @@ pub mod connection; pub mod connection_pool; pub mod connection_pool_builder; pub mod cursor; +pub mod inner_connection; pub mod listener; pub mod transaction; pub mod transaction_options; diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 3fa59e4d..2fa38ba5 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -6,22 +6,20 @@ use pyo3::{ pyclass, types::{PyList, PyTuple}, }; -use tokio_postgres::binary_copy::BinaryCopyInWriter; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, }; use super::{ - connection::PsqlpyConnection, cursor::Cursor, + inner_connection::PsqlpyConnection, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; -use crate::common::ObjectQueryTrait; -use std::{collections::HashSet, sync::Arc}; +use std::{collections::HashSet, net::IpAddr, sync::Arc}; #[allow(clippy::module_name_repetitions)] pub trait TransactionObjectTrait { @@ -107,6 +105,7 @@ impl TransactionObjectTrait for PsqlpyConnection { #[pyclass(subclass)] pub struct Transaction { pub db_client: Option>, + pg_config: Arc, is_started: bool, is_done: bool, @@ -123,6 +122,7 @@ impl Transaction { #[must_use] pub fn new( db_client: Arc, + pg_config: Arc, is_started: bool, is_done: bool, isolation_level: Option, @@ -133,6 +133,7 @@ impl Transaction { ) -> Self { Self { db_client: Some(db_client), + pg_config, is_started, is_done, isolation_level, @@ -160,6 +161,78 @@ impl Transaction { #[pymethods] impl Transaction { + #[getter] + fn conn_dbname(&self) -> Option<&str> { + self.pg_config.get_dbname() + } + + #[getter] + fn user(&self) -> Option<&str> { + self.pg_config.get_user() + } + + #[getter] + fn host_addrs(&self) -> Vec { + let mut host_addrs_vec = vec![]; + + let host_addrs = self.pg_config.get_hostaddrs(); + for ip_addr in host_addrs { + match ip_addr { + IpAddr::V4(ipv4) => { + host_addrs_vec.push(ipv4.to_string()); + } + IpAddr::V6(ipv6) => { + host_addrs_vec.push(ipv6.to_string()); + } + } + } + + host_addrs_vec + } + + #[cfg(unix)] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + Host::Unix(host) => { + hosts_vec.push(host.display().to_string()); + } + } + } + + hosts_vec + } + + #[cfg(not(unix))] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + _ => unreachable!(), + } + } + + hosts_vec + } + + #[getter] + fn ports(&self) -> Vec<&u16> { + return self.pg_config.get_ports().iter().collect::>(); + } + #[must_use] pub fn __aiter__(self_: Py) -> Py { self_ @@ -328,9 +401,7 @@ impl Transaction { }); is_transaction_ready?; if let Some(db_client) = db_client { - return db_client - .psqlpy_query(querystring, parameters, prepared) - .await; + return db_client.execute(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::TransactionClosedError) @@ -384,9 +455,7 @@ impl Transaction { }); is_transaction_ready?; if let Some(db_client) = db_client { - return db_client - .psqlpy_query(querystring, parameters, prepared) - .await; + return db_client.execute(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::TransactionClosedError) @@ -420,36 +489,7 @@ impl Transaction { is_transaction_ready?; if let Some(db_client) = db_client { - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - - let result = if prepared.unwrap_or(true) { - db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - } else { - db_client - .query_one( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - }; - - return Ok(PSQLDriverSinglePyQueryResult::new(result)); + return db_client.fetch_row(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::TransactionClosedError) @@ -476,41 +516,9 @@ impl Transaction { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) }); + is_transaction_ready?; if let Some(db_client) = db_client { - is_transaction_ready?; - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } - - let result = if prepared.unwrap_or(true) { - db_client - .query_one( - &db_client.prepare_cached(&querystring).await?, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - } else { - db_client - .query_one( - &querystring, - ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await? - }; - - return Python::with_gil(|gil| match result.columns().first() { - Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None), - None => Ok(gil.None()), - }); + return db_client.fetch_val(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::TransactionClosedError) @@ -537,51 +545,11 @@ impl Transaction { (self_.check_is_transaction_ready(), self_.db_client.clone()) }); + is_transaction_ready?; if let Some(db_client) = db_client { - is_transaction_ready?; - - let mut params: Vec> = vec![]; - if let Some(parameters) = parameters { - for vec_of_py_any in parameters { - params.push(convert_parameters(vec_of_py_any)?); - } - } - let prepared = prepared.unwrap_or(true); - - for param in params { - let is_query_result_ok = if prepared { - let prepared_stmt = &db_client.prepare_cached(&querystring).await; - if let Err(error) = prepared_stmt { - return Err(RustPSQLDriverError::TransactionExecuteError(format!( - "Cannot prepare statement in execute_many, operation rolled back {error}", - ))); - } - db_client - .query( - &db_client.prepare_cached(&querystring).await?, - ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - } else { - db_client - .query( - &querystring, - ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(), - ) - .await - }; - is_query_result_ok?; - } - - return Ok(()); + return db_client + .execute_many(querystring, parameters, prepared) + .await; } Err(RustPSQLDriverError::TransactionClosedError) @@ -804,9 +772,9 @@ impl Transaction { (self_.check_is_transaction_ready(), self_.db_client.clone()) }); - if let Some(db_client) = db_client { - is_transaction_ready?; + is_transaction_ready?; + if let Some(db_client) = db_client { let mut futures = vec![]; if let Some(queries) = queries { let gil_result = pyo3::Python::with_gil(|gil| -> PyResult<()> { @@ -822,7 +790,7 @@ impl Transaction { Ok(param) => Some(param.into()), Err(_) => None, }; - futures.push(db_client.psqlpy_query(querystring, params, prepared)); + futures.push(db_client.execute(querystring, params, prepared)); } Ok(()) }); @@ -863,6 +831,7 @@ impl Transaction { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), + self.pg_config.clone(), querystring, parameters, "cur_name".into(),