diff --git a/Cargo.lock b/Cargo.lock index 85f93be3..fee82b45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,9 +707,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "openssl" @@ -997,7 +997,7 @@ dependencies = [ [[package]] name = "psqlpy" -version = "0.9.2" +version = "0.9.3" dependencies = [ "byteorder", "bytes", @@ -1010,6 +1010,7 @@ dependencies = [ "geo-types", "itertools", "macaddr", + "once_cell", "openssl", "openssl-src", "openssl-sys", @@ -1021,6 +1022,7 @@ dependencies = [ "postgres_array", "pyo3", "pyo3-async-runtimes", + "regex", "rust_decimal 1.36.0", "serde", "serde_json", @@ -1192,9 +1194,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -1204,9 +1206,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -1215,9 +1217,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rend" diff --git a/Cargo.toml b/Cargo.toml index 710c59cf..1846f8c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "psqlpy" -version = "0.9.2" +version = "0.9.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -54,3 +54,5 @@ pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = ] } futures-channel = "0.3.31" futures = "0.3.31" +regex = "1.11.1" +once_cell = "1.20.3" diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts index 94b082f9..d3afac19 100644 --- a/docs/.vuepress/sidebar.ts +++ b/docs/.vuepress/sidebar.ts @@ -33,6 +33,7 @@ export default sidebar({ prefix: "usage/", collapsible: true, children: [ + "parameters", { text: "Types", prefix: "types/", diff --git a/docs/usage/parameters.md b/docs/usage/parameters.md new file mode 100644 index 00000000..596fbbc5 --- /dev/null +++ b/docs/usage/parameters.md @@ -0,0 +1,42 @@ +--- +title: Passing parameters to SQL queries +--- + +We support two variant of passing parameters to sql queries. + +::: tabs +@tab Parameters sequence + +You can pass parameters as some python Sequence. + +Placeholders in querystring must be marked as `$1`, `$2` and so on, +depending on how many parameters you have. + +```python +async def main(): + ... + + await connection.execute( + querystring="SELECT * FROM users WHERE id = $1", + parameters=(101,), + ) +``` + +@tab Parameters mapping + +If you prefer use named arguments, we support it too. +Placeholder in querystring must look like `$(parameter)p`. + +If you don't pass parameter but have it in querystring, exception will be raised. + +```python +async def main(): + ... + + await connection.execute( + querystring="SELECT * FROM users WHERE id = $(user_id)p", + parameters=dict(user_id=101), + ) +``` + +::: \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bc906612..84c00f42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ warn_unused_ignores = false [tool.ruff] fix = true unsafe-fixes = true -line-length = 120 +line-length = 89 exclude = [".venv/", "psqlpy-stress"] [tool.ruff.format] diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 77ec440d..8c391d96 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -2,9 +2,9 @@ import types from enum import Enum from io import BytesIO from ipaddress import IPv4Address, IPv6Address -from typing import Any, Awaitable, Callable, Sequence, TypeVar +from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar -from typing_extensions import Buffer, Self +from typing_extensions import Buffer, Self, TypeAlias _CustomClass = TypeVar( "_CustomClass", @@ -13,6 +13,8 @@ _RowFactoryRV = TypeVar( "_RowFactoryRV", ) +ParamsT: TypeAlias = Sequence[Any] | Mapping[str, Any] | None + class QueryResult: """Result.""" @@ -150,7 +152,7 @@ class SingleQueryResult: class SynchronousCommit(Enum): """ - Class for synchronous_commit option for transactions. + Synchronous_commit option for transactions. ### Variants: - `On`: The meaning may change based on whether you have @@ -181,7 +183,7 @@ class SynchronousCommit(Enum): RemoteApply = 5 class IsolationLevel(Enum): - """Class for Isolation Level for transactions.""" + """Isolation Level for transactions.""" ReadUncommitted = 1 ReadCommitted = 2 @@ -290,7 +292,7 @@ class Cursor: cursor_name: str querystring: str - parameters: Sequence[Any] + parameters: ParamsT = None prepared: bool | None conn_dbname: str | None user: str | None @@ -464,7 +466,7 @@ class Transaction: async def execute( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> QueryResult: """Execute the query. @@ -554,7 +556,7 @@ class Transaction: async def fetch( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> QueryResult: """Fetch the result from database. @@ -574,7 +576,7 @@ class Transaction: async def fetch_row( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> SingleQueryResult: """Fetch exaclty single row from query. @@ -613,7 +615,7 @@ class Transaction: async def fetch_val( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> Any | None: """Execute the query and return first value of the first row. @@ -814,7 +816,7 @@ class Transaction: def cursor( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, fetch_number: int | None = None, scroll: bool | None = None, prepared: bool = True, @@ -906,7 +908,7 @@ class Connection: async def execute( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> QueryResult: """Execute the query. @@ -990,7 +992,7 @@ class Connection: async def fetch( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> QueryResult: """Fetch the result from database. @@ -1010,7 +1012,7 @@ class Connection: async def fetch_row( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> SingleQueryResult: """Fetch exaclty single row from query. @@ -1046,7 +1048,7 @@ class Connection: async def fetch_val( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, prepared: bool = True, ) -> Any: """Execute the query and return first value of the first row. @@ -1100,7 +1102,7 @@ class Connection: def cursor( self: Self, querystring: str, - parameters: Sequence[Any] | None = None, + parameters: ParamsT = None, fetch_number: int | None = None, scroll: bool | None = None, prepared: bool = True, @@ -1708,10 +1710,13 @@ class ConnectionPoolBuilder: self: Self, keepalives_retries: int, ) -> Self: - """ - Set the maximum number of TCP keepalive probes that will be sent before dropping a connection. + """Keepalives Retries. + + Set the maximum number of TCP keepalive probes + that will be sent before dropping a connection. - This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + This is ignored for Unix domain sockets, + or if the `keepalives` option is disabled. ### Parameters: - `keepalives_retries`: number of retries. diff --git a/python/psqlpy/_internal/extra_types.pyi b/python/psqlpy/_internal/extra_types.pyi index 93037639..e29c7573 100644 --- a/python/psqlpy/_internal/extra_types.pyi +++ b/python/psqlpy/_internal/extra_types.pyi @@ -143,9 +143,13 @@ class MacAddr8: class CustomType: def __init__(self, value: bytes) -> None: ... -Coordinates: TypeAlias = list[int | float] | set[int | float] | tuple[int | float, int | float] +Coordinates: TypeAlias = ( + list[int | float] | set[int | float] | tuple[int | float, int | float] +) PairsOfCoordinates: TypeAlias = ( - list[Coordinates | int | float] | set[Coordinates | int | float] | tuple[Coordinates | int | float, ...] + list[Coordinates | int | float] + | set[Coordinates | int | float] + | tuple[Coordinates | int | float, ...] ) class Point: @@ -227,7 +231,9 @@ class Circle: def __init__( self: Self, - value: list[int | float] | set[int | float] | tuple[int | float, int | float, int | float], + value: list[int | float] + | set[int | float] + | tuple[int | float, int | float, int | float], ) -> None: """Create new instance of Circle. @@ -374,7 +380,11 @@ class IpAddressArray: def __init__( self: Self, inner: typing.Sequence[ - IPv4Address | IPv6Address | typing.Sequence[IPv4Address] | typing.Sequence[IPv6Address] | typing.Any, + IPv4Address + | IPv6Address + | typing.Sequence[IPv4Address] + | typing.Sequence[IPv6Address] + | typing.Any, ], ) -> None: """Create new instance of IpAddressArray. diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 4a388f62..30426e5f 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -73,6 +73,11 @@ def listener_table_name() -> str: return random_string() +@pytest.fixture +def map_parameters_table_name() -> str: + return random_string() + + @pytest.fixture def number_database_records() -> int: return random.randint(10, 35) @@ -161,6 +166,23 @@ async def create_table_for_listener_tests( ) +@pytest.fixture +async def create_table_for_map_parameters_test( + psql_pool: ConnectionPool, + map_parameters_table_name: str, +) -> AsyncGenerator[None, None]: + connection = await psql_pool.connection() + await connection.execute( + f"CREATE TABLE {map_parameters_table_name}" + "(id SERIAL, name VARCHAR(255),surname VARCHAR(255), age INT)", + ) + + yield + await connection.execute( + f"DROP TABLE {map_parameters_table_name}", + ) + + @pytest.fixture async def test_cursor( psql_pool: ConnectionPool, diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 898cc405..7af208f2 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -183,7 +183,10 @@ async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: connection = await psql_pool.connection() await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch") await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2") - query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" + query = ( + "CREATE TABLE execute_batch (name VARCHAR);" + "CREATE TABLE execute_batch2 (name VARCHAR);" + ) async with psql_pool.acquire() as conn: await conn.execute_batch(querystring=query) await conn.execute(querystring="SELECT * FROM execute_batch") diff --git a/python/tests/test_kwargs_parameters.py b/python/tests/test_kwargs_parameters.py new file mode 100644 index 00000000..d1fb1ebf --- /dev/null +++ b/python/tests/test_kwargs_parameters.py @@ -0,0 +1,73 @@ +import pytest +from psqlpy import ConnectionPool +from psqlpy.exceptions import ( + PyToRustValueMappingError, +) + +pytestmark = pytest.mark.anyio + + +async def test_success_default_map_parameters( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + async with psql_pool.acquire() as conn: + exist_records = await conn.execute( + f"SELECT * FROM {table_name}", + ) + result = exist_records.result() + + test_fetch = await conn.execute( + f"SELECT * FROM {table_name} WHERE id = $(id_)p", + parameters={ + "id_": result[0]["id"], + }, + ) + + assert test_fetch.result()[0]["id"] == result[0]["id"] + assert test_fetch.result()[0]["name"] == result[0]["name"] + + +@pytest.mark.usefixtures("create_table_for_map_parameters_test") +async def test_success_multiple_same_parameters( + psql_pool: ConnectionPool, + map_parameters_table_name: str, +) -> None: + test_name_surname = "Surname" + test_age = 1 + async with psql_pool.acquire() as conn: + await conn.execute( + querystring=( + f"INSERT INTO {map_parameters_table_name} " + "(name, surname, age) VALUES ($(name)p, $(surname)p, $(age)p)" + ), + parameters={ + "name": test_name_surname, + "surname": test_name_surname, + "age": test_age, + }, + ) + + res = await conn.execute( + querystring=( + f"SELECT * FROM {map_parameters_table_name} " + "WHERE name = $(name)p OR surname = $(name)p" + ), + parameters={"name": test_name_surname}, + ) + + assert res.result()[0]["name"] == test_name_surname + assert res.result()[0]["surname"] == test_name_surname + assert res.result()[0]["age"] == test_age + + +async def test_failed_no_parameter( + psql_pool: ConnectionPool, + table_name: str, +) -> None: + async with psql_pool.acquire() as conn: + with pytest.raises(expected_exception=PyToRustValueMappingError): + await conn.execute( + querystring=(f"SELECT * FROM {table_name} " "WHERE name = $(name)p"), # noqa: ISC001 + parameters={"mistake": "wow"}, + ) diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py index c48c8974..a1ff0742 100644 --- a/python/tests/test_listener.py +++ b/python/tests/test_listener.py @@ -126,7 +126,7 @@ async def test_listener_listen( listener_table_name=listener_table_name, ) - listener.abort_listen() + await listener.shutdown() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -152,7 +152,7 @@ async def test_listener_asynciterator( assert listener_msg.payload == TEST_PAYLOAD break - listener.abort_listen() + await listener.shutdown() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -175,7 +175,7 @@ async def test_listener_abort( listener_table_name=listener_table_name, ) - listener.abort_listen() + await listener.shutdown() await clear_test_table( psql_pool=psql_pool, @@ -261,7 +261,7 @@ async def test_listener_more_than_one_callback( assert data_result["channel"] == additional_channel - listener.abort_listen() + await listener.shutdown() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -290,7 +290,7 @@ async def test_listener_clear_callbacks( is_insert_exist=False, ) - listener.abort_listen() + await listener.shutdown() @pytest.mark.usefixtures("create_table_for_listener_tests") @@ -317,4 +317,4 @@ async def test_listener_clear_all_callbacks( is_insert_exist=False, ) - listener.abort_listen() + await listener.shutdown() diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index 151f5bb5..3c60676a 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -352,7 +352,10 @@ async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: connection = await psql_pool.connection() await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch") await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2") - query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);" + query = ( + "CREATE TABLE execute_batch (name VARCHAR);" + "CREATE TABLE execute_batch2 (name VARCHAR);" + ) async with connection.transaction() as transaction: await transaction.execute_batch(querystring=query) await transaction.execute(querystring="SELECT * FROM execute_batch") @@ -377,7 +380,9 @@ async def test_synchronous_commit( table_name: str, number_database_records: int, ) -> None: - async with psql_pool.acquire() as conn, conn.transaction(synchronous_commit=synchronous_commit) as trans: + async with psql_pool.acquire() as conn, conn.transaction( + synchronous_commit=synchronous_commit, + ) as trans: res = await trans.execute( f"SELECT * FROM {table_name}", ) diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index de62c554..34361b22 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -134,7 +134,7 @@ async def test_as_class( ("TEXT", "Some String", "Some String"), ( "XML", - """Manual...""", + """Manual...""", # noqa: E501 """Manual...""", ), ("BOOL", True, True), @@ -151,7 +151,11 @@ async def test_as_class( ("TIME", now_datetime.time(), now_datetime.time()), ("TIMESTAMP", now_datetime, now_datetime), ("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz), - ("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta), + ( + "TIMESTAMPTZ", + now_datetime_with_tz_in_asia_jakarta, + now_datetime_with_tz_in_asia_jakarta, + ), ("UUID", uuid_, str(uuid_)), ("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")), ( @@ -246,7 +250,11 @@ async def test_as_class( ("INT2 ARRAY", [SmallInt(12), SmallInt(100)], [12, 100]), ("INT2 ARRAY", [[SmallInt(12)], [SmallInt(100)]], [[12], [100]]), ("INT4 ARRAY", [Integer(121231231), Integer(121231231)], [121231231, 121231231]), - ("INT4 ARRAY", [[Integer(121231231)], [Integer(121231231)]], [[121231231], [121231231]]), + ( + "INT4 ARRAY", + [[Integer(121231231)], [Integer(121231231)]], + [[121231231], [121231231]], + ), ( "INT8 ARRAY", [BigInt(99999999999999999), BigInt(99999999999999999)], @@ -308,7 +316,11 @@ async def test_as_class( [[now_datetime.time()], [now_datetime.time()]], ), ("TIMESTAMP ARRAY", [now_datetime, now_datetime], [now_datetime, now_datetime]), - ("TIMESTAMP ARRAY", [[now_datetime], [now_datetime]], [[now_datetime], [now_datetime]]), + ( + "TIMESTAMP ARRAY", + [[now_datetime], [now_datetime]], + [[now_datetime], [now_datetime]], + ), ( "TIMESTAMPTZ ARRAY", [now_datetime_with_tz, now_datetime_with_tz], @@ -638,8 +650,14 @@ async def test_as_class( ), ( "INTERVAL ARRAY", - [datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)], - [datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)], + [ + datetime.timedelta(days=100, microseconds=100), + datetime.timedelta(days=100, microseconds=100), + ], + [ + datetime.timedelta(days=100, microseconds=100), + datetime.timedelta(days=100, microseconds=100), + ], ), ], ) @@ -681,7 +699,9 @@ async def test_deserialization_composite_into_python( await connection.execute("DROP TYPE IF EXISTS inner_type") await connection.execute("DROP TYPE IF EXISTS enum_type") await connection.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')") - await connection.execute("CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)") + await connection.execute( + "CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)", + ) create_type_query = """ CREATE type all_types AS ( bytea_ BYTEA, @@ -1082,7 +1102,12 @@ async def test_empty_array( async with psql_pool.acquire() as conn: await conn.execute("DROP TABLE IF EXISTS test_earr") await conn.execute( - "CREATE TABLE test_earr (id serial NOT NULL PRIMARY KEY, e_array text[] NOT NULL DEFAULT array[]::text[])", + """ + CREATE TABLE test_earr ( + id serial NOT NULL PRIMARY KEY, + e_array text[] NOT NULL DEFAULT array[]::text[] + ) + """, ) await conn.execute("INSERT INTO test_earr(id) VALUES(2);") @@ -1125,8 +1150,16 @@ async def test_empty_array( ("INT2 ARRAY", Int16Array([]), []), ("INT2 ARRAY", Int16Array([SmallInt(12), SmallInt(100)]), [12, 100]), ("INT2 ARRAY", Int16Array([[SmallInt(12)], [SmallInt(100)]]), [[12], [100]]), - ("INT4 ARRAY", Int32Array([Integer(121231231), Integer(121231231)]), [121231231, 121231231]), - ("INT4 ARRAY", Int32Array([[Integer(121231231)], [Integer(121231231)]]), [[121231231], [121231231]]), + ( + "INT4 ARRAY", + Int32Array([Integer(121231231), Integer(121231231)]), + [121231231, 121231231], + ), + ( + "INT4 ARRAY", + Int32Array([[Integer(121231231)], [Integer(121231231)]]), + [[121231231], [121231231]], + ), ( "INT8 ARRAY", Int64Array([BigInt(99999999999999999), BigInt(99999999999999999)]), @@ -1187,8 +1220,16 @@ async def test_empty_array( TimeArray([[now_datetime.time()], [now_datetime.time()]]), [[now_datetime.time()], [now_datetime.time()]], ), - ("TIMESTAMP ARRAY", DateTimeArray([now_datetime, now_datetime]), [now_datetime, now_datetime]), - ("TIMESTAMP ARRAY", DateTimeArray([[now_datetime], [now_datetime]]), [[now_datetime], [now_datetime]]), + ( + "TIMESTAMP ARRAY", + DateTimeArray([now_datetime, now_datetime]), + [now_datetime, now_datetime], + ), + ( + "TIMESTAMP ARRAY", + DateTimeArray([[now_datetime], [now_datetime]]), + [[now_datetime], [now_datetime]], + ), ( "TIMESTAMPTZ ARRAY", DateTimeTZArray([now_datetime_with_tz, now_datetime_with_tz]), @@ -1554,9 +1595,15 @@ async def test_empty_array( ( "INTERVAL ARRAY", IntervalArray( - [[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]], + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], ), - [[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]], + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], ), ], ) diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index c66006cc..10b861f1 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -8,7 +8,7 @@ use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, + value_converter::{convert_parameters_and_qs, postgres_to_py, PythonDTO, QueryParameter}, }; #[allow(clippy::module_name_repetitions)] @@ -82,7 +82,7 @@ impl PsqlpyConnection { } } - pub async fn execute( + pub async fn cursor_execute( &self, querystring: String, parameters: Option>, @@ -90,10 +90,7 @@ impl PsqlpyConnection { ) -> RustPSQLDriverPyResult { let prepared = prepared.unwrap_or(true); - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } + let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; let boxed_params = ¶ms .iter() @@ -103,7 +100,7 @@ impl PsqlpyConnection { let result = if prepared { self.query( - &self.prepare_cached(&querystring).await.map_err(|err| { + &self.prepare_cached(&qs).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) @@ -117,13 +114,53 @@ impl PsqlpyConnection { )) })? } else { - self.query(&querystring, boxed_params) - .await - .map_err(|err| { + self.query(&qs, boxed_params).await.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + }; + + Ok(PSQLDriverPyQueryResult::new(result)) + } + + pub async fn execute( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> RustPSQLDriverPyResult { + let prepared = prepared.unwrap_or(true); + + let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + + let boxed_params = ¶ms + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice(); + + let result = if prepared { + self.query( + &self.prepare_cached(&qs).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" + "Cannot prepare statement, error - {err}" )) - })? + })?, + boxed_params, + ) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? + } else { + self.query(&qs, boxed_params).await.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -131,7 +168,7 @@ impl PsqlpyConnection { pub async fn execute_many( &self, - querystring: String, + mut querystring: String, parameters: Option>>, prepared: Option, ) -> RustPSQLDriverPyResult<()> { @@ -140,7 +177,11 @@ impl PsqlpyConnection { let mut params: Vec> = vec![]; if let Some(parameters) = parameters { for vec_of_py_any in parameters { - params.push(convert_parameters(vec_of_py_any)?); + // TODO: Fix multiple qs creation + let (qs, parsed_params) = + convert_parameters_and_qs(querystring.clone(), Some(vec_of_py_any))?; + querystring = qs; + params.push(parsed_params); } } @@ -182,10 +223,7 @@ impl PsqlpyConnection { ) -> RustPSQLDriverPyResult { let prepared = prepared.unwrap_or(true); - let mut params: Vec = vec![]; - if let Some(parameters) = parameters { - params = convert_parameters(parameters)?; - } + let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; let boxed_params = ¶ms .iter() @@ -195,7 +233,7 @@ impl PsqlpyConnection { let result = if prepared { self.query_one( - &self.prepare_cached(&querystring).await.map_err(|err| { + &self.prepare_cached(&qs).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) @@ -209,13 +247,11 @@ impl PsqlpyConnection { )) })? } else { - self.query_one(&querystring, boxed_params) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query_one(&qs, boxed_params).await.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute statement, error - {err}" + )) + })? }; return Ok(result); diff --git a/src/value_converter.rs b/src/value_converter.rs index fd9fca37..f3a95297 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -3,11 +3,12 @@ use chrono_tz::Tz; use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect}; use itertools::Itertools; use macaddr::{MacAddr6, MacAddr8}; +use once_cell::sync::Lazy; use pg_interval::Interval; use postgres_types::{Field, FromSql, Kind, ToSql}; use rust_decimal::Decimal; use serde_json::{json, Map, Value}; -use std::{fmt::Debug, net::IpAddr}; +use std::{collections::HashMap, fmt::Debug, net::IpAddr, sync::RwLock}; use uuid::Uuid; use bytes::{BufMut, BytesMut}; @@ -16,8 +17,8 @@ use pyo3::{ sync::GILOnceCell, types::{ PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyDictMethods, PyFloat, - PyInt, PyList, PyListMethods, PySequence, PySet, PyString, PyTime, PyTuple, PyType, - PyTypeMethods, + PyInt, PyList, PyListMethods, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, + PyType, PyTypeMethods, }, Bound, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; @@ -39,16 +40,15 @@ use postgres_array::{array::Array, Dimension}; static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); static TIMEDELTA_CLS: GILOnceCell> = GILOnceCell::new(); +static KWARGS_QUERYSTRINGS: Lazy)>>> = + Lazy::new(|| RwLock::new(Default::default())); pub type QueryParameter = (dyn ToSql + Sync); fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> { DECIMAL_CLS .get_or_try_init(py, || { - let type_object = py - .import_bound("decimal")? - .getattr("Decimal")? - .downcast_into()?; + let type_object = py.import("decimal")?.getattr("Decimal")?.downcast_into()?; Ok(type_object.unbind()) }) .map(|ty| ty.bind(py)) @@ -58,7 +58,7 @@ fn get_timedelta_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> { TIMEDELTA_CLS .get_or_try_init(py, || { let type_object = py - .import_bound("datetime")? + .import("datetime")? .getattr("timedelta")? .downcast_into()?; Ok(type_object.unbind()) @@ -613,6 +613,73 @@ impl ToSql for PythonDTO { to_sql_checked!(); } +fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { + let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap(); + + { + let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); + let qs = kq_read.get(querystring); + + if let Some(qs) = qs { + return qs.clone(); + } + }; + + let mut counter = 0; + let mut sequence = Vec::new(); + + let result = re.replace_all(querystring, |caps: ®ex::Captures| { + let account_id = caps[1].to_string(); + + sequence.push(account_id.clone()); + counter += 1; + + format!("${}", &counter) + }); + + let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap(); + kq_write.insert( + querystring.to_string(), + (result.clone().into(), sequence.clone()), + ); + (result.into(), sequence) +} + +pub fn convert_kwargs_parameters<'a>( + kw_params: &Bound<'_, PyMapping>, + querystring: &'a str, +) -> RustPSQLDriverPyResult<(String, Vec)> { + let mut result_vec: Vec = vec![]; + let (changed_string, params_names) = parse_kwargs_qs(querystring); + + for param_name in params_names { + match kw_params.get_item(¶m_name) { + Ok(param) => result_vec.push(py_to_rust(¶m)?), + Err(_) => { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + format!("Cannot find parameter with name <{param_name}> in parameters").into(), + )) + } + } + } + + Ok((changed_string, result_vec)) +} + +pub fn convert_seq_parameters( + seq_params: Vec>, +) -> RustPSQLDriverPyResult> { + let mut result_vec: Vec = vec![]; + Python::with_gil(|gil| { + for parameter in seq_params { + result_vec.push(py_to_rust(parameter.bind(gil))?); + } + Ok::<(), RustPSQLDriverError>(()) + })?; + + Ok(result_vec) +} + /// Convert parameters come from python. /// /// Parameters for `execute()` method can be either @@ -625,22 +692,36 @@ impl ToSql for PythonDTO { /// /// May return Err Result if can't convert python object. #[allow(clippy::needless_pass_by_value)] -pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult> { - let mut result_vec: Vec = vec![]; - Python::with_gil(|gil| { +pub fn convert_parameters_and_qs( + querystring: String, + parameters: Option>, +) -> RustPSQLDriverPyResult<(String, Vec)> { + let Some(parameters) = parameters else { + return Ok((querystring, vec![])); + }; + + let res = Python::with_gil(|gil| { let params = parameters.extract::>>(gil).map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "Cannot convert you parameters argument into Rust type, please use List/Tuple" .into(), ) - })?; - for parameter in params { - result_vec.push(py_to_rust(parameter.bind(gil))?); + }); + if let Ok(params) = params { + return Ok((querystring, convert_seq_parameters(params)?)); } - Ok::<(), RustPSQLDriverError>(()) + + let kw_params = parameters.downcast_bound::(gil); + if let Ok(kw_params) = kw_params { + return convert_kwargs_parameters(kw_params, &querystring); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "Parameters must be sequence or mapping".into(), + )) })?; - Ok(result_vec) + Ok(res) } /// Convert Sequence from Python (except String) into flat vec.