diff --git a/Cargo.lock b/Cargo.lock
index 85f93be3..fee82b45 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -707,9 +707,9 @@ dependencies = [
[[package]]
name = "once_cell"
-version = "1.19.0"
+version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
+checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "openssl"
@@ -997,7 +997,7 @@ dependencies = [
[[package]]
name = "psqlpy"
-version = "0.9.2"
+version = "0.9.3"
dependencies = [
"byteorder",
"bytes",
@@ -1010,6 +1010,7 @@ dependencies = [
"geo-types",
"itertools",
"macaddr",
+ "once_cell",
"openssl",
"openssl-src",
"openssl-sys",
@@ -1021,6 +1022,7 @@ dependencies = [
"postgres_array",
"pyo3",
"pyo3-async-runtimes",
+ "regex",
"rust_decimal 1.36.0",
"serde",
"serde_json",
@@ -1192,9 +1194,9 @@ dependencies = [
[[package]]
name = "regex"
-version = "1.10.6"
+version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
+checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
@@ -1204,9 +1206,9 @@ dependencies = [
[[package]]
name = "regex-automata"
-version = "0.4.7"
+version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
+checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
@@ -1215,9 +1217,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
-version = "0.8.4"
+version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
+checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rend"
diff --git a/Cargo.toml b/Cargo.toml
index 710c59cf..1846f8c2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "psqlpy"
-version = "0.9.2"
+version = "0.9.3"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -54,3 +54,5 @@ pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch =
] }
futures-channel = "0.3.31"
futures = "0.3.31"
+regex = "1.11.1"
+once_cell = "1.20.3"
diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts
index 94b082f9..d3afac19 100644
--- a/docs/.vuepress/sidebar.ts
+++ b/docs/.vuepress/sidebar.ts
@@ -33,6 +33,7 @@ export default sidebar({
prefix: "usage/",
collapsible: true,
children: [
+ "parameters",
{
text: "Types",
prefix: "types/",
diff --git a/docs/usage/parameters.md b/docs/usage/parameters.md
new file mode 100644
index 00000000..596fbbc5
--- /dev/null
+++ b/docs/usage/parameters.md
@@ -0,0 +1,42 @@
+---
+title: Passing parameters to SQL queries
+---
+
+We support two variant of passing parameters to sql queries.
+
+::: tabs
+@tab Parameters sequence
+
+You can pass parameters as some python Sequence.
+
+Placeholders in querystring must be marked as `$1`, `$2` and so on,
+depending on how many parameters you have.
+
+```python
+async def main():
+ ...
+
+ await connection.execute(
+ querystring="SELECT * FROM users WHERE id = $1",
+ parameters=(101,),
+ )
+```
+
+@tab Parameters mapping
+
+If you prefer use named arguments, we support it too.
+Placeholder in querystring must look like `$(parameter)p`.
+
+If you don't pass parameter but have it in querystring, exception will be raised.
+
+```python
+async def main():
+ ...
+
+ await connection.execute(
+ querystring="SELECT * FROM users WHERE id = $(user_id)p",
+ parameters=dict(user_id=101),
+ )
+```
+
+:::
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index bc906612..84c00f42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ warn_unused_ignores = false
[tool.ruff]
fix = true
unsafe-fixes = true
-line-length = 120
+line-length = 89
exclude = [".venv/", "psqlpy-stress"]
[tool.ruff.format]
diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi
index 77ec440d..8c391d96 100644
--- a/python/psqlpy/_internal/__init__.pyi
+++ b/python/psqlpy/_internal/__init__.pyi
@@ -2,9 +2,9 @@ import types
from enum import Enum
from io import BytesIO
from ipaddress import IPv4Address, IPv6Address
-from typing import Any, Awaitable, Callable, Sequence, TypeVar
+from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar
-from typing_extensions import Buffer, Self
+from typing_extensions import Buffer, Self, TypeAlias
_CustomClass = TypeVar(
"_CustomClass",
@@ -13,6 +13,8 @@ _RowFactoryRV = TypeVar(
"_RowFactoryRV",
)
+ParamsT: TypeAlias = Sequence[Any] | Mapping[str, Any] | None
+
class QueryResult:
"""Result."""
@@ -150,7 +152,7 @@ class SingleQueryResult:
class SynchronousCommit(Enum):
"""
- Class for synchronous_commit option for transactions.
+ Synchronous_commit option for transactions.
### Variants:
- `On`: The meaning may change based on whether you have
@@ -181,7 +183,7 @@ class SynchronousCommit(Enum):
RemoteApply = 5
class IsolationLevel(Enum):
- """Class for Isolation Level for transactions."""
+ """Isolation Level for transactions."""
ReadUncommitted = 1
ReadCommitted = 2
@@ -290,7 +292,7 @@ class Cursor:
cursor_name: str
querystring: str
- parameters: Sequence[Any]
+ parameters: ParamsT = None
prepared: bool | None
conn_dbname: str | None
user: str | None
@@ -464,7 +466,7 @@ class Transaction:
async def execute(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> QueryResult:
"""Execute the query.
@@ -554,7 +556,7 @@ class Transaction:
async def fetch(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> QueryResult:
"""Fetch the result from database.
@@ -574,7 +576,7 @@ class Transaction:
async def fetch_row(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> SingleQueryResult:
"""Fetch exaclty single row from query.
@@ -613,7 +615,7 @@ class Transaction:
async def fetch_val(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> Any | None:
"""Execute the query and return first value of the first row.
@@ -814,7 +816,7 @@ class Transaction:
def cursor(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
fetch_number: int | None = None,
scroll: bool | None = None,
prepared: bool = True,
@@ -906,7 +908,7 @@ class Connection:
async def execute(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> QueryResult:
"""Execute the query.
@@ -990,7 +992,7 @@ class Connection:
async def fetch(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> QueryResult:
"""Fetch the result from database.
@@ -1010,7 +1012,7 @@ class Connection:
async def fetch_row(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> SingleQueryResult:
"""Fetch exaclty single row from query.
@@ -1046,7 +1048,7 @@ class Connection:
async def fetch_val(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
prepared: bool = True,
) -> Any:
"""Execute the query and return first value of the first row.
@@ -1100,7 +1102,7 @@ class Connection:
def cursor(
self: Self,
querystring: str,
- parameters: Sequence[Any] | None = None,
+ parameters: ParamsT = None,
fetch_number: int | None = None,
scroll: bool | None = None,
prepared: bool = True,
@@ -1708,10 +1710,13 @@ class ConnectionPoolBuilder:
self: Self,
keepalives_retries: int,
) -> Self:
- """
- Set the maximum number of TCP keepalive probes that will be sent before dropping a connection.
+ """Keepalives Retries.
+
+ Set the maximum number of TCP keepalive probes
+ that will be sent before dropping a connection.
- This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
+ This is ignored for Unix domain sockets,
+ or if the `keepalives` option is disabled.
### Parameters:
- `keepalives_retries`: number of retries.
diff --git a/python/psqlpy/_internal/extra_types.pyi b/python/psqlpy/_internal/extra_types.pyi
index 93037639..e29c7573 100644
--- a/python/psqlpy/_internal/extra_types.pyi
+++ b/python/psqlpy/_internal/extra_types.pyi
@@ -143,9 +143,13 @@ class MacAddr8:
class CustomType:
def __init__(self, value: bytes) -> None: ...
-Coordinates: TypeAlias = list[int | float] | set[int | float] | tuple[int | float, int | float]
+Coordinates: TypeAlias = (
+ list[int | float] | set[int | float] | tuple[int | float, int | float]
+)
PairsOfCoordinates: TypeAlias = (
- list[Coordinates | int | float] | set[Coordinates | int | float] | tuple[Coordinates | int | float, ...]
+ list[Coordinates | int | float]
+ | set[Coordinates | int | float]
+ | tuple[Coordinates | int | float, ...]
)
class Point:
@@ -227,7 +231,9 @@ class Circle:
def __init__(
self: Self,
- value: list[int | float] | set[int | float] | tuple[int | float, int | float, int | float],
+ value: list[int | float]
+ | set[int | float]
+ | tuple[int | float, int | float, int | float],
) -> None:
"""Create new instance of Circle.
@@ -374,7 +380,11 @@ class IpAddressArray:
def __init__(
self: Self,
inner: typing.Sequence[
- IPv4Address | IPv6Address | typing.Sequence[IPv4Address] | typing.Sequence[IPv6Address] | typing.Any,
+ IPv4Address
+ | IPv6Address
+ | typing.Sequence[IPv4Address]
+ | typing.Sequence[IPv6Address]
+ | typing.Any,
],
) -> None:
"""Create new instance of IpAddressArray.
diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 4a388f62..30426e5f 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -73,6 +73,11 @@ def listener_table_name() -> str:
return random_string()
+@pytest.fixture
+def map_parameters_table_name() -> str:
+ return random_string()
+
+
@pytest.fixture
def number_database_records() -> int:
return random.randint(10, 35)
@@ -161,6 +166,23 @@ async def create_table_for_listener_tests(
)
+@pytest.fixture
+async def create_table_for_map_parameters_test(
+ psql_pool: ConnectionPool,
+ map_parameters_table_name: str,
+) -> AsyncGenerator[None, None]:
+ connection = await psql_pool.connection()
+ await connection.execute(
+ f"CREATE TABLE {map_parameters_table_name}"
+ "(id SERIAL, name VARCHAR(255),surname VARCHAR(255), age INT)",
+ )
+
+ yield
+ await connection.execute(
+ f"DROP TABLE {map_parameters_table_name}",
+ )
+
+
@pytest.fixture
async def test_cursor(
psql_pool: ConnectionPool,
diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py
index 898cc405..7af208f2 100644
--- a/python/tests/test_connection.py
+++ b/python/tests/test_connection.py
@@ -183,7 +183,10 @@ async def test_execute_batch_method(psql_pool: ConnectionPool) -> None:
connection = await psql_pool.connection()
await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch")
await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2")
- query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);"
+ query = (
+ "CREATE TABLE execute_batch (name VARCHAR);"
+ "CREATE TABLE execute_batch2 (name VARCHAR);"
+ )
async with psql_pool.acquire() as conn:
await conn.execute_batch(querystring=query)
await conn.execute(querystring="SELECT * FROM execute_batch")
diff --git a/python/tests/test_kwargs_parameters.py b/python/tests/test_kwargs_parameters.py
new file mode 100644
index 00000000..d1fb1ebf
--- /dev/null
+++ b/python/tests/test_kwargs_parameters.py
@@ -0,0 +1,73 @@
+import pytest
+from psqlpy import ConnectionPool
+from psqlpy.exceptions import (
+ PyToRustValueMappingError,
+)
+
+pytestmark = pytest.mark.anyio
+
+
+async def test_success_default_map_parameters(
+ psql_pool: ConnectionPool,
+ table_name: str,
+) -> None:
+ async with psql_pool.acquire() as conn:
+ exist_records = await conn.execute(
+ f"SELECT * FROM {table_name}",
+ )
+ result = exist_records.result()
+
+ test_fetch = await conn.execute(
+ f"SELECT * FROM {table_name} WHERE id = $(id_)p",
+ parameters={
+ "id_": result[0]["id"],
+ },
+ )
+
+ assert test_fetch.result()[0]["id"] == result[0]["id"]
+ assert test_fetch.result()[0]["name"] == result[0]["name"]
+
+
+@pytest.mark.usefixtures("create_table_for_map_parameters_test")
+async def test_success_multiple_same_parameters(
+ psql_pool: ConnectionPool,
+ map_parameters_table_name: str,
+) -> None:
+ test_name_surname = "Surname"
+ test_age = 1
+ async with psql_pool.acquire() as conn:
+ await conn.execute(
+ querystring=(
+ f"INSERT INTO {map_parameters_table_name} "
+ "(name, surname, age) VALUES ($(name)p, $(surname)p, $(age)p)"
+ ),
+ parameters={
+ "name": test_name_surname,
+ "surname": test_name_surname,
+ "age": test_age,
+ },
+ )
+
+ res = await conn.execute(
+ querystring=(
+ f"SELECT * FROM {map_parameters_table_name} "
+ "WHERE name = $(name)p OR surname = $(name)p"
+ ),
+ parameters={"name": test_name_surname},
+ )
+
+ assert res.result()[0]["name"] == test_name_surname
+ assert res.result()[0]["surname"] == test_name_surname
+ assert res.result()[0]["age"] == test_age
+
+
+async def test_failed_no_parameter(
+ psql_pool: ConnectionPool,
+ table_name: str,
+) -> None:
+ async with psql_pool.acquire() as conn:
+ with pytest.raises(expected_exception=PyToRustValueMappingError):
+ await conn.execute(
+ querystring=(f"SELECT * FROM {table_name} " "WHERE name = $(name)p"), # noqa: ISC001
+ parameters={"mistake": "wow"},
+ )
diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py
index c48c8974..a1ff0742 100644
--- a/python/tests/test_listener.py
+++ b/python/tests/test_listener.py
@@ -126,7 +126,7 @@ async def test_listener_listen(
listener_table_name=listener_table_name,
)
- listener.abort_listen()
+ await listener.shutdown()
@pytest.mark.usefixtures("create_table_for_listener_tests")
@@ -152,7 +152,7 @@ async def test_listener_asynciterator(
assert listener_msg.payload == TEST_PAYLOAD
break
- listener.abort_listen()
+ await listener.shutdown()
@pytest.mark.usefixtures("create_table_for_listener_tests")
@@ -175,7 +175,7 @@ async def test_listener_abort(
listener_table_name=listener_table_name,
)
- listener.abort_listen()
+ await listener.shutdown()
await clear_test_table(
psql_pool=psql_pool,
@@ -261,7 +261,7 @@ async def test_listener_more_than_one_callback(
assert data_result["channel"] == additional_channel
- listener.abort_listen()
+ await listener.shutdown()
@pytest.mark.usefixtures("create_table_for_listener_tests")
@@ -290,7 +290,7 @@ async def test_listener_clear_callbacks(
is_insert_exist=False,
)
- listener.abort_listen()
+ await listener.shutdown()
@pytest.mark.usefixtures("create_table_for_listener_tests")
@@ -317,4 +317,4 @@ async def test_listener_clear_all_callbacks(
is_insert_exist=False,
)
- listener.abort_listen()
+ await listener.shutdown()
diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py
index 151f5bb5..3c60676a 100644
--- a/python/tests/test_transaction.py
+++ b/python/tests/test_transaction.py
@@ -352,7 +352,10 @@ async def test_execute_batch_method(psql_pool: ConnectionPool) -> None:
connection = await psql_pool.connection()
await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch")
await connection.execute(querystring="DROP TABLE IF EXISTS execute_batch2")
- query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);"
+ query = (
+ "CREATE TABLE execute_batch (name VARCHAR);"
+ "CREATE TABLE execute_batch2 (name VARCHAR);"
+ )
async with connection.transaction() as transaction:
await transaction.execute_batch(querystring=query)
await transaction.execute(querystring="SELECT * FROM execute_batch")
@@ -377,7 +380,9 @@ async def test_synchronous_commit(
table_name: str,
number_database_records: int,
) -> None:
- async with psql_pool.acquire() as conn, conn.transaction(synchronous_commit=synchronous_commit) as trans:
+ async with psql_pool.acquire() as conn, conn.transaction(
+ synchronous_commit=synchronous_commit,
+ ) as trans:
res = await trans.execute(
f"SELECT * FROM {table_name}",
)
diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py
index de62c554..34361b22 100644
--- a/python/tests/test_value_converter.py
+++ b/python/tests/test_value_converter.py
@@ -134,7 +134,7 @@ async def test_as_class(
("TEXT", "Some String", "Some String"),
(
"XML",
- """Manual...""",
+ """Manual...""", # noqa: E501
"""Manual...""",
),
("BOOL", True, True),
@@ -151,7 +151,11 @@ async def test_as_class(
("TIME", now_datetime.time(), now_datetime.time()),
("TIMESTAMP", now_datetime, now_datetime),
("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz),
- ("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta),
+ (
+ "TIMESTAMPTZ",
+ now_datetime_with_tz_in_asia_jakarta,
+ now_datetime_with_tz_in_asia_jakarta,
+ ),
("UUID", uuid_, str(uuid_)),
("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")),
(
@@ -246,7 +250,11 @@ async def test_as_class(
("INT2 ARRAY", [SmallInt(12), SmallInt(100)], [12, 100]),
("INT2 ARRAY", [[SmallInt(12)], [SmallInt(100)]], [[12], [100]]),
("INT4 ARRAY", [Integer(121231231), Integer(121231231)], [121231231, 121231231]),
- ("INT4 ARRAY", [[Integer(121231231)], [Integer(121231231)]], [[121231231], [121231231]]),
+ (
+ "INT4 ARRAY",
+ [[Integer(121231231)], [Integer(121231231)]],
+ [[121231231], [121231231]],
+ ),
(
"INT8 ARRAY",
[BigInt(99999999999999999), BigInt(99999999999999999)],
@@ -308,7 +316,11 @@ async def test_as_class(
[[now_datetime.time()], [now_datetime.time()]],
),
("TIMESTAMP ARRAY", [now_datetime, now_datetime], [now_datetime, now_datetime]),
- ("TIMESTAMP ARRAY", [[now_datetime], [now_datetime]], [[now_datetime], [now_datetime]]),
+ (
+ "TIMESTAMP ARRAY",
+ [[now_datetime], [now_datetime]],
+ [[now_datetime], [now_datetime]],
+ ),
(
"TIMESTAMPTZ ARRAY",
[now_datetime_with_tz, now_datetime_with_tz],
@@ -638,8 +650,14 @@ async def test_as_class(
),
(
"INTERVAL ARRAY",
- [datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)],
- [datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)],
+ [
+ datetime.timedelta(days=100, microseconds=100),
+ datetime.timedelta(days=100, microseconds=100),
+ ],
+ [
+ datetime.timedelta(days=100, microseconds=100),
+ datetime.timedelta(days=100, microseconds=100),
+ ],
),
],
)
@@ -681,7 +699,9 @@ async def test_deserialization_composite_into_python(
await connection.execute("DROP TYPE IF EXISTS inner_type")
await connection.execute("DROP TYPE IF EXISTS enum_type")
await connection.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')")
- await connection.execute("CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)")
+ await connection.execute(
+ "CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)",
+ )
create_type_query = """
CREATE type all_types AS (
bytea_ BYTEA,
@@ -1082,7 +1102,12 @@ async def test_empty_array(
async with psql_pool.acquire() as conn:
await conn.execute("DROP TABLE IF EXISTS test_earr")
await conn.execute(
- "CREATE TABLE test_earr (id serial NOT NULL PRIMARY KEY, e_array text[] NOT NULL DEFAULT array[]::text[])",
+ """
+ CREATE TABLE test_earr (
+ id serial NOT NULL PRIMARY KEY,
+ e_array text[] NOT NULL DEFAULT array[]::text[]
+ )
+ """,
)
await conn.execute("INSERT INTO test_earr(id) VALUES(2);")
@@ -1125,8 +1150,16 @@ async def test_empty_array(
("INT2 ARRAY", Int16Array([]), []),
("INT2 ARRAY", Int16Array([SmallInt(12), SmallInt(100)]), [12, 100]),
("INT2 ARRAY", Int16Array([[SmallInt(12)], [SmallInt(100)]]), [[12], [100]]),
- ("INT4 ARRAY", Int32Array([Integer(121231231), Integer(121231231)]), [121231231, 121231231]),
- ("INT4 ARRAY", Int32Array([[Integer(121231231)], [Integer(121231231)]]), [[121231231], [121231231]]),
+ (
+ "INT4 ARRAY",
+ Int32Array([Integer(121231231), Integer(121231231)]),
+ [121231231, 121231231],
+ ),
+ (
+ "INT4 ARRAY",
+ Int32Array([[Integer(121231231)], [Integer(121231231)]]),
+ [[121231231], [121231231]],
+ ),
(
"INT8 ARRAY",
Int64Array([BigInt(99999999999999999), BigInt(99999999999999999)]),
@@ -1187,8 +1220,16 @@ async def test_empty_array(
TimeArray([[now_datetime.time()], [now_datetime.time()]]),
[[now_datetime.time()], [now_datetime.time()]],
),
- ("TIMESTAMP ARRAY", DateTimeArray([now_datetime, now_datetime]), [now_datetime, now_datetime]),
- ("TIMESTAMP ARRAY", DateTimeArray([[now_datetime], [now_datetime]]), [[now_datetime], [now_datetime]]),
+ (
+ "TIMESTAMP ARRAY",
+ DateTimeArray([now_datetime, now_datetime]),
+ [now_datetime, now_datetime],
+ ),
+ (
+ "TIMESTAMP ARRAY",
+ DateTimeArray([[now_datetime], [now_datetime]]),
+ [[now_datetime], [now_datetime]],
+ ),
(
"TIMESTAMPTZ ARRAY",
DateTimeTZArray([now_datetime_with_tz, now_datetime_with_tz]),
@@ -1554,9 +1595,15 @@ async def test_empty_array(
(
"INTERVAL ARRAY",
IntervalArray(
- [[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]],
+ [
+ [datetime.timedelta(days=100, microseconds=100)],
+ [datetime.timedelta(days=100, microseconds=100)],
+ ],
),
- [[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]],
+ [
+ [datetime.timedelta(days=100, microseconds=100)],
+ [datetime.timedelta(days=100, microseconds=100)],
+ ],
),
],
)
diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs
index c66006cc..10b861f1 100644
--- a/src/driver/inner_connection.rs
+++ b/src/driver/inner_connection.rs
@@ -8,7 +8,7 @@ use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement};
use crate::{
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult},
- value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter},
+ value_converter::{convert_parameters_and_qs, postgres_to_py, PythonDTO, QueryParameter},
};
#[allow(clippy::module_name_repetitions)]
@@ -82,7 +82,7 @@ impl PsqlpyConnection {
}
}
- pub async fn execute(
+ pub async fn cursor_execute(
&self,
querystring: String,
parameters: Option>,
@@ -90,10 +90,7 @@ impl PsqlpyConnection {
) -> RustPSQLDriverPyResult {
let prepared = prepared.unwrap_or(true);
- let mut params: Vec = vec![];
- if let Some(parameters) = parameters {
- params = convert_parameters(parameters)?;
- }
+ let (qs, params) = convert_parameters_and_qs(querystring, parameters)?;
let boxed_params = ¶ms
.iter()
@@ -103,7 +100,7 @@ impl PsqlpyConnection {
let result = if prepared {
self.query(
- &self.prepare_cached(&querystring).await.map_err(|err| {
+ &self.prepare_cached(&qs).await.map_err(|err| {
RustPSQLDriverError::ConnectionExecuteError(format!(
"Cannot prepare statement, error - {err}"
))
@@ -117,13 +114,53 @@ impl PsqlpyConnection {
))
})?
} else {
- self.query(&querystring, boxed_params)
- .await
- .map_err(|err| {
+ self.query(&qs, boxed_params).await.map_err(|err| {
+ RustPSQLDriverError::ConnectionExecuteError(format!(
+ "Cannot execute statement, error - {err}"
+ ))
+ })?
+ };
+
+ Ok(PSQLDriverPyQueryResult::new(result))
+ }
+
+ pub async fn execute(
+ &self,
+ querystring: String,
+ parameters: Option>,
+ prepared: Option,
+ ) -> RustPSQLDriverPyResult {
+ let prepared = prepared.unwrap_or(true);
+
+ let (qs, params) = convert_parameters_and_qs(querystring, parameters)?;
+
+ let boxed_params = ¶ms
+ .iter()
+ .map(|param| param as &QueryParameter)
+ .collect::>()
+ .into_boxed_slice();
+
+ let result = if prepared {
+ self.query(
+ &self.prepare_cached(&qs).await.map_err(|err| {
RustPSQLDriverError::ConnectionExecuteError(format!(
- "Cannot execute statement, error - {err}"
+ "Cannot prepare statement, error - {err}"
))
- })?
+ })?,
+ boxed_params,
+ )
+ .await
+ .map_err(|err| {
+ RustPSQLDriverError::ConnectionExecuteError(format!(
+ "Cannot execute statement, error - {err}"
+ ))
+ })?
+ } else {
+ self.query(&qs, boxed_params).await.map_err(|err| {
+ RustPSQLDriverError::ConnectionExecuteError(format!(
+ "Cannot execute statement, error - {err}"
+ ))
+ })?
};
Ok(PSQLDriverPyQueryResult::new(result))
@@ -131,7 +168,7 @@ impl PsqlpyConnection {
pub async fn execute_many(
&self,
- querystring: String,
+ mut querystring: String,
parameters: Option>>,
prepared: Option,
) -> RustPSQLDriverPyResult<()> {
@@ -140,7 +177,11 @@ impl PsqlpyConnection {
let mut params: Vec> = vec![];
if let Some(parameters) = parameters {
for vec_of_py_any in parameters {
- params.push(convert_parameters(vec_of_py_any)?);
+ // TODO: Fix multiple qs creation
+ let (qs, parsed_params) =
+ convert_parameters_and_qs(querystring.clone(), Some(vec_of_py_any))?;
+ querystring = qs;
+ params.push(parsed_params);
}
}
@@ -182,10 +223,7 @@ impl PsqlpyConnection {
) -> RustPSQLDriverPyResult {
let prepared = prepared.unwrap_or(true);
- let mut params: Vec = vec![];
- if let Some(parameters) = parameters {
- params = convert_parameters(parameters)?;
- }
+ let (qs, params) = convert_parameters_and_qs(querystring, parameters)?;
let boxed_params = ¶ms
.iter()
@@ -195,7 +233,7 @@ impl PsqlpyConnection {
let result = if prepared {
self.query_one(
- &self.prepare_cached(&querystring).await.map_err(|err| {
+ &self.prepare_cached(&qs).await.map_err(|err| {
RustPSQLDriverError::ConnectionExecuteError(format!(
"Cannot prepare statement, error - {err}"
))
@@ -209,13 +247,11 @@ impl PsqlpyConnection {
))
})?
} else {
- self.query_one(&querystring, boxed_params)
- .await
- .map_err(|err| {
- RustPSQLDriverError::ConnectionExecuteError(format!(
- "Cannot execute statement, error - {err}"
- ))
- })?
+ self.query_one(&qs, boxed_params).await.map_err(|err| {
+ RustPSQLDriverError::ConnectionExecuteError(format!(
+ "Cannot execute statement, error - {err}"
+ ))
+ })?
};
return Ok(result);
diff --git a/src/value_converter.rs b/src/value_converter.rs
index fd9fca37..f3a95297 100644
--- a/src/value_converter.rs
+++ b/src/value_converter.rs
@@ -3,11 +3,12 @@ use chrono_tz::Tz;
use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect};
use itertools::Itertools;
use macaddr::{MacAddr6, MacAddr8};
+use once_cell::sync::Lazy;
use pg_interval::Interval;
use postgres_types::{Field, FromSql, Kind, ToSql};
use rust_decimal::Decimal;
use serde_json::{json, Map, Value};
-use std::{fmt::Debug, net::IpAddr};
+use std::{collections::HashMap, fmt::Debug, net::IpAddr, sync::RwLock};
use uuid::Uuid;
use bytes::{BufMut, BytesMut};
@@ -16,8 +17,8 @@ use pyo3::{
sync::GILOnceCell,
types::{
PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyDictMethods, PyFloat,
- PyInt, PyList, PyListMethods, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
- PyTypeMethods,
+ PyInt, PyList, PyListMethods, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple,
+ PyType, PyTypeMethods,
},
Bound, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
};
@@ -39,16 +40,15 @@ use postgres_array::{array::Array, Dimension};
static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new();
static TIMEDELTA_CLS: GILOnceCell> = GILOnceCell::new();
+static KWARGS_QUERYSTRINGS: Lazy)>>> =
+ Lazy::new(|| RwLock::new(Default::default()));
pub type QueryParameter = (dyn ToSql + Sync);
fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
DECIMAL_CLS
.get_or_try_init(py, || {
- let type_object = py
- .import_bound("decimal")?
- .getattr("Decimal")?
- .downcast_into()?;
+ let type_object = py.import("decimal")?.getattr("Decimal")?.downcast_into()?;
Ok(type_object.unbind())
})
.map(|ty| ty.bind(py))
@@ -58,7 +58,7 @@ fn get_timedelta_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
TIMEDELTA_CLS
.get_or_try_init(py, || {
let type_object = py
- .import_bound("datetime")?
+ .import("datetime")?
.getattr("timedelta")?
.downcast_into()?;
Ok(type_object.unbind())
@@ -613,6 +613,73 @@ impl ToSql for PythonDTO {
to_sql_checked!();
}
+fn parse_kwargs_qs(querystring: &str) -> (String, Vec) {
+ let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap();
+
+ {
+ let kq_read = KWARGS_QUERYSTRINGS.read().unwrap();
+ let qs = kq_read.get(querystring);
+
+ if let Some(qs) = qs {
+ return qs.clone();
+ }
+ };
+
+ let mut counter = 0;
+ let mut sequence = Vec::new();
+
+ let result = re.replace_all(querystring, |caps: ®ex::Captures| {
+ let account_id = caps[1].to_string();
+
+ sequence.push(account_id.clone());
+ counter += 1;
+
+ format!("${}", &counter)
+ });
+
+ let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap();
+ kq_write.insert(
+ querystring.to_string(),
+ (result.clone().into(), sequence.clone()),
+ );
+ (result.into(), sequence)
+}
+
+pub fn convert_kwargs_parameters<'a>(
+ kw_params: &Bound<'_, PyMapping>,
+ querystring: &'a str,
+) -> RustPSQLDriverPyResult<(String, Vec)> {
+ let mut result_vec: Vec = vec![];
+ let (changed_string, params_names) = parse_kwargs_qs(querystring);
+
+ for param_name in params_names {
+ match kw_params.get_item(¶m_name) {
+ Ok(param) => result_vec.push(py_to_rust(¶m)?),
+ Err(_) => {
+ return Err(RustPSQLDriverError::PyToRustValueConversionError(
+ format!("Cannot find parameter with name <{param_name}> in parameters").into(),
+ ))
+ }
+ }
+ }
+
+ Ok((changed_string, result_vec))
+}
+
+pub fn convert_seq_parameters(
+ seq_params: Vec>,
+) -> RustPSQLDriverPyResult> {
+ let mut result_vec: Vec = vec![];
+ Python::with_gil(|gil| {
+ for parameter in seq_params {
+ result_vec.push(py_to_rust(parameter.bind(gil))?);
+ }
+ Ok::<(), RustPSQLDriverError>(())
+ })?;
+
+ Ok(result_vec)
+}
+
/// Convert parameters come from python.
///
/// Parameters for `execute()` method can be either
@@ -625,22 +692,36 @@ impl ToSql for PythonDTO {
///
/// May return Err Result if can't convert python object.
#[allow(clippy::needless_pass_by_value)]
-pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult> {
- let mut result_vec: Vec = vec![];
- Python::with_gil(|gil| {
+pub fn convert_parameters_and_qs(
+ querystring: String,
+ parameters: Option>,
+) -> RustPSQLDriverPyResult<(String, Vec)> {
+ let Some(parameters) = parameters else {
+ return Ok((querystring, vec![]));
+ };
+
+ let res = Python::with_gil(|gil| {
let params = parameters.extract::>>(gil).map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError(
"Cannot convert you parameters argument into Rust type, please use List/Tuple"
.into(),
)
- })?;
- for parameter in params {
- result_vec.push(py_to_rust(parameter.bind(gil))?);
+ });
+ if let Ok(params) = params {
+ return Ok((querystring, convert_seq_parameters(params)?));
}
- Ok::<(), RustPSQLDriverError>(())
+
+ let kw_params = parameters.downcast_bound::(gil);
+ if let Ok(kw_params) = kw_params {
+ return convert_kwargs_parameters(kw_params, &querystring);
+ }
+
+ Err(RustPSQLDriverError::PyToRustValueConversionError(
+ "Parameters must be sequence or mapping".into(),
+ ))
})?;
- Ok(result_vec)
+ Ok(res)
}
/// Convert Sequence from Python (except String) into flat vec.