diff --git a/docs/components/connection.md b/docs/components/connection.md index 7a470b06..1e82d99a 100644 --- a/docs/components/connection.md +++ b/docs/components/connection.md @@ -200,7 +200,7 @@ async def main() -> None: ) ``` -### Back To Pool +### Close Returns connection to the pool. It's crucial to commit all transactions and close all cursor which are made from the connection. Otherwise, this method won't do anything useful. @@ -213,5 +213,5 @@ There is no need in this method if you use async context manager. async def main() -> None: ... connection = await db_pool.connection() - connection.back_to_pool() + connection.close() ``` diff --git a/pyproject.toml b/pyproject.toml index cd2b2f42..c6f3b2e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,3 +106,6 @@ ignore = [ [tool.ruff.pydocstyle] convention = "pep257" ignore-decorators = ["typing.overload"] + +[project.entry-points."sqlalchemy.dialects"] +psqlpy = "psqlpy_sqlalchemy.dialect:PSQLPyAsyncDialect" diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index fbaf123d..7c76be33 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -18,6 +18,9 @@ connect, connect_pool, ) +from psqlpy.exceptions import ( + Error, +) __all__ = [ "ConnRecyclingMethod", @@ -25,6 +28,7 @@ "ConnectionPool", "ConnectionPoolBuilder", "Cursor", + "Error", "IsolationLevel", "KeepaliveConfig", "Listener", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index d900d228..10ac499e 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -253,11 +253,12 @@ class KeepaliveConfig: """Initialize new config.""" class Cursor: - """Represent opened cursor in a transaction. + """Represent binary cursor in a transaction. It can be used as an asynchronous iterator. """ + array_size: int cursor_name: str querystring: str parameters: ParamsT = None @@ -282,118 +283,27 @@ class Cursor: Execute DECLARE command for the cursor. """ - async def close(self: Self) -> None: + def close(self: Self) -> None: """Close the cursor. Execute CLOSE command for the cursor. """ - async def fetch( - self: Self, - fetch_number: int | None = None, - ) -> QueryResult: - """Fetch next rows. - - By default fetches 10 next rows. - - ### Parameters: - - `fetch_number`: how many rows need to fetch. - - ### Returns: - result as `QueryResult`. - """ - async def fetch_next( - self: Self, - ) -> QueryResult: - """Fetch next row. - - Execute FETCH NEXT - - ### Returns: - result as `QueryResult`. - """ - async def fetch_prior( - self: Self, - ) -> QueryResult: - """Fetch previous row. - - Execute FETCH PRIOR - - ### Returns: - result as `QueryResult`. - """ - async def fetch_first( - self: Self, - ) -> QueryResult: - """Fetch first row. - - Execute FETCH FIRST - - ### Returns: - result as `QueryResult`. - """ - async def fetch_last( - self: Self, - ) -> QueryResult: - """Fetch last row. - - Execute FETCH LAST - - ### Returns: - result as `QueryResult`. - """ - async def fetch_absolute( - self: Self, - absolute_number: int, - ) -> QueryResult: - """Fetch absolute rows. - - Execute FETCH ABSOLUTE . - - ### Returns: - result as `QueryResult`. - """ - async def fetch_relative( - self: Self, - relative_number: int, - ) -> QueryResult: - """Fetch absolute rows. - - Execute FETCH RELATIVE . - - ### Returns: - result as `QueryResult`. - """ - async def fetch_forward_all( - self: Self, - ) -> QueryResult: - """Fetch forward all rows. - - Execute FETCH FORWARD ALL. - - ### Returns: - result as `QueryResult`. - """ - async def fetch_backward( - self: Self, - backward_count: int, - ) -> QueryResult: - """Fetch backward rows. - - Execute FETCH BACKWARD . - - ### Returns: - result as `QueryResult`. - """ - async def fetch_backward_all( + async def execute( self: Self, + querystring: str, + parameters: ParamsT = None, ) -> QueryResult: - """Fetch backward all rows. - - Execute FETCH BACKWARD ALL. + """Start cursor with querystring and parameters. - ### Returns: - result as `QueryResult`. + Method should be used instead of context manager + and `start` method. """ + async def fetchone(self: Self) -> QueryResult: + """Return next one row from the cursor.""" + async def fetchmany(self: Self, size: int | None = None) -> QueryResult: + """Return rows from the cursor.""" + async def fetchall(self: Self, size: int | None = None) -> QueryResult: + """Return all remaining rows from the cursor.""" class Transaction: """Single connection for executing queries. @@ -1098,8 +1008,6 @@ class Connection: querystring: str, parameters: ParamsT = None, fetch_number: int | None = None, - scroll: bool | None = None, - prepared: bool = True, ) -> Cursor: """Create new cursor object. @@ -1136,7 +1044,7 @@ class Connection: ... # do something with this result. ``` """ - def back_to_pool(self: Self) -> None: + def close(self: Self) -> None: """Return connection back to the pool. It necessary to commit all transactions and close all cursor diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 7af208f2..4643e0c6 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -145,7 +145,7 @@ async def test_connection_cursor( await transaction.begin() cursor = connection.cursor(querystring=f"SELECT * FROM {table_name}") await cursor.start() - await cursor.close() + cursor.close() await transaction.commit() @@ -172,7 +172,7 @@ async def test_closed_connection_error( ) -> None: """Test exception when connection is closed.""" connection = await psql_pool.connection() - connection.back_to_pool() + connection.close() with pytest.raises(expected_exception=ConnectionClosedError): await connection.execute("SELECT 1") diff --git a/python/tests/test_cursor.py b/python/tests/test_cursor.py index 07fca375..fdd53ba7 100644 --- a/python/tests/test_cursor.py +++ b/python/tests/test_cursor.py @@ -1,177 +1,68 @@ from __future__ import annotations -import math from typing import TYPE_CHECKING import pytest if TYPE_CHECKING: - from psqlpy import ConnectionPool, Cursor, QueryResult, Transaction + from psqlpy import ConnectionPool, Cursor pytestmark = pytest.mark.anyio -async def test_cursor_fetch( +async def test_cursor_fetchmany( number_database_records: int, test_cursor: Cursor, ) -> None: """Test cursor fetch with custom number of fetch.""" - result = await test_cursor.fetch(fetch_number=number_database_records // 2) + result = await test_cursor.fetchmany(size=number_database_records // 2) assert len(result.result()) == number_database_records // 2 -async def test_cursor_fetch_next( +async def test_cursor_fetchone( test_cursor: Cursor, ) -> None: - """Test cursor fetch next.""" - result = await test_cursor.fetch_next() + result = await test_cursor.fetchone() assert len(result.result()) == 1 -async def test_cursor_fetch_prior( - test_cursor: Cursor, -) -> None: - """Test cursor fetch prior.""" - result = await test_cursor.fetch_prior() - assert len(result.result()) == 0 - - await test_cursor.fetch(fetch_number=2) - result = await test_cursor.fetch_prior() - assert len(result.result()) == 1 - - -async def test_cursor_fetch_first( - test_cursor: Cursor, -) -> None: - """Test cursor fetch first.""" - fetch_first = await test_cursor.fetch(fetch_number=1) - - await test_cursor.fetch(fetch_number=3) - - first = await test_cursor.fetch_first() - - assert fetch_first.result() == first.result() - - -async def test_cursor_fetch_last( - test_cursor: Cursor, +async def test_cursor_fetchall( number_database_records: int, -) -> None: - """Test cursor fetch last.""" - all_res = await test_cursor.fetch( - fetch_number=number_database_records, - ) - - last_res = await test_cursor.fetch_last() - - assert all_res.result()[-1] == last_res.result()[0] - - -async def test_cursor_fetch_absolute( - test_cursor: Cursor, - number_database_records: int, -) -> None: - """Test cursor fetch Absolute.""" - all_res = await test_cursor.fetch( - fetch_number=number_database_records, - ) - - first_record = await test_cursor.fetch_absolute( - absolute_number=1, - ) - last_record = await test_cursor.fetch_absolute( - absolute_number=-1, - ) - - assert all_res.result()[0] == first_record.result()[0] - assert all_res.result()[-1] == last_record.result()[0] - - -async def test_cursor_fetch_relative( test_cursor: Cursor, - number_database_records: int, ) -> None: - """Test cursor fetch Relative.""" - first_absolute = await test_cursor.fetch_relative( - relative_number=1, - ) - - assert first_absolute.result() - - await test_cursor.fetch( - fetch_number=number_database_records, - ) - records = await test_cursor.fetch_relative( - relative_number=1, - ) + result = await test_cursor.fetchall() + assert len(result.result()) == number_database_records - assert not (records.result()) - -async def test_cursor_fetch_forward_all( - test_cursor: Cursor, +async def test_cursor_start( + psql_pool: ConnectionPool, + table_name: str, number_database_records: int, ) -> None: - """Test that cursor execute FETCH FORWARD ALL correctly.""" - default_fetch_number = 2 - await test_cursor.fetch(fetch_number=default_fetch_number) - - rest_results = await test_cursor.fetch_forward_all() - - assert len(rest_results.result()) == number_database_records - default_fetch_number - - -async def test_cursor_fetch_backward( - test_cursor: Cursor, -) -> None: - """Test cursor backward fetch.""" - must_be_empty = await test_cursor.fetch_backward(backward_count=10) - assert not (must_be_empty.result()) - - default_fetch_number = 5 - await test_cursor.fetch(fetch_number=default_fetch_number) - - expected_number_of_results = 3 - must_not_be_empty = await test_cursor.fetch_backward( - backward_count=expected_number_of_results, + connection = await psql_pool.connection() + cursor = connection.cursor( + querystring=f"SELECT * FROM {table_name}", ) - assert len(must_not_be_empty.result()) == expected_number_of_results - - -async def test_cursor_fetch_backward_all( - test_cursor: Cursor, -) -> None: - """Test cursor `fetch_backward_all`.""" - must_be_empty = await test_cursor.fetch_backward_all() - assert not (must_be_empty.result()) + await cursor.start() + results = await cursor.fetchall() - default_fetch_number = 5 - await test_cursor.fetch(fetch_number=default_fetch_number) + assert len(results.result()) == number_database_records - must_not_be_empty = await test_cursor.fetch_backward_all() - assert len(must_not_be_empty.result()) == default_fetch_number - 1 + cursor.close() -async def test_cursor_as_async_manager( +async def test_cursor_as_async_context_manager( psql_pool: ConnectionPool, table_name: str, number_database_records: int, ) -> None: - """Test cursor async manager and async iterator.""" connection = await psql_pool.connection() - transaction: Transaction - cursor: Cursor - all_results: list[QueryResult] = [] - expected_num_results = math.ceil(number_database_records / 3) - fetch_number = 3 - async with connection.transaction() as transaction, transaction.cursor( + async with connection.cursor( querystring=f"SELECT * FROM {table_name}", - fetch_number=fetch_number, ) as cursor: - async for result in cursor: - all_results.append(result) # noqa: PERF401 + results = await cursor.fetchall() - assert len(all_results) == expected_num_results + assert len(results.result()) == number_database_records async def test_cursor_send_underlying_connection_to_pool( @@ -184,7 +75,7 @@ async def test_cursor_send_underlying_connection_to_pool( async with transaction.cursor( querystring=f"SELECT * FROM {table_name}", ) as cursor: - await cursor.fetch(10) + await cursor.fetchmany(10) assert not psql_pool.status().available assert not psql_pool.status().available assert not psql_pool.status().available @@ -200,9 +91,9 @@ async def test_cursor_send_underlying_connection_to_pool_manually( async with connection.transaction() as transaction: cursor = transaction.cursor(querystring=f"SELECT * FROM {table_name}") await cursor.start() - await cursor.fetch(10) + await cursor.fetchmany(10) assert not psql_pool.status().available - await cursor.close() + cursor.close() assert not psql_pool.status().available assert not psql_pool.status().available assert psql_pool.status().available == 1 diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py index a1ff0742..8db12ae9 100644 --- a/python/tests/test_listener.py +++ b/python/tests/test_listener.py @@ -64,7 +64,7 @@ async def notify( connection = await psql_pool.connection() await connection.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") - connection.back_to_pool() + connection.close() async def check_insert_callback( @@ -91,7 +91,7 @@ async def check_insert_callback( assert data_record["payload"] == TEST_PAYLOAD assert data_record["channel"] == TEST_CHANNEL - connection.back_to_pool() + connection.close() async def clear_test_table( @@ -102,7 +102,7 @@ async def clear_test_table( await connection.execute( f"DELETE FROM {listener_table_name}", ) - connection.back_to_pool() + connection.close() @pytest.mark.usefixtures("create_table_for_listener_tests") diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index a186084a..343bb868 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -180,7 +180,7 @@ async def test_transaction_rollback( f"INSERT INTO {table_name} VALUES ($1, $2)", parameters=[100, test_name], ) - connection.back_to_pool() + connection.close() assert not (result_from_conn.result()) @@ -358,4 +358,4 @@ async def test_execute_batch_method(psql_pool: ConnectionPool) -> None: await transaction.execute(querystring="SELECT * FROM execute_batch") await transaction.execute(querystring="SELECT * FROM execute_batch2") - connection.back_to_pool() + connection.close() diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 84683edb..931770aa 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -1,11 +1,8 @@ -use std::sync::{Arc, RwLock}; - use bytes::Buf; use pyo3::{PyAny, Python}; use tokio_postgres::{CopyInSink, Portal as tp_Portal, Row, Statement, ToStatement}; use crate::{ - driver::portal::Portal, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, options::{IsolationLevel, ReadVariant}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, @@ -19,7 +16,7 @@ use tokio_postgres::Transaction as tp_Transaction; use super::{ structs::{PSQLPyConnection, PoolConnection, SingleConnection}, - traits::{CloseTransaction, Connection, Cursor, StartTransaction, Transaction}, + traits::{CloseTransaction, Connection, StartTransaction, Transaction}, }; impl Transaction for T @@ -327,52 +324,42 @@ impl CloseTransaction for PSQLPyConnection { } } -impl Cursor for PSQLPyConnection { - async fn start_cursor( - &mut self, - cursor_name: &str, - scroll: &Option, +impl PSQLPyConnection { + pub fn in_transaction(&self) -> bool { + match self { + PSQLPyConnection::PoolConn(conn) => conn.in_transaction, + PSQLPyConnection::SingleConnection(conn) => conn.in_transaction, + } + } + + pub async fn prepare_statement( + &self, querystring: String, - prepared: &Option, parameters: Option>, - ) -> PSQLPyResult<()> { - let cursor_qs = self.build_cursor_start_qs(cursor_name, scroll, &querystring); - self.execute(cursor_qs, parameters, *prepared) + ) -> PSQLPyResult { + StatementBuilder::new(&querystring, ¶meters, self, Some(true)) + .build() .await - .map_err(|err| { - RustPSQLDriverError::CursorStartError(format!("Cannot start cursor due to {err}")) - })?; - match self { - PSQLPyConnection::PoolConn(conn) => conn.in_cursor = true, - PSQLPyConnection::SingleConnection(conn) => conn.in_cursor = true, - } - Ok(()) } - async fn close_cursor(&mut self, cursor_name: &str) -> PSQLPyResult<()> { - self.execute( - format!("CLOSE {cursor_name}"), - Option::default(), - Some(false), - ) - .await?; + pub async fn execute_statement( + &self, + statement: &PsqlpyStatement, + ) -> PSQLPyResult { + let result = self + .query(statement.statement_query()?, &statement.params()) + .await?; - match self { - PSQLPyConnection::PoolConn(conn) => conn.in_cursor = false, - PSQLPyConnection::SingleConnection(conn) => conn.in_cursor = false, - } - Ok(()) + Ok(PSQLDriverPyQueryResult::new(result)) } -} -impl PSQLPyConnection { pub async fn execute( &self, querystring: String, parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let statement = StatementBuilder::new(querystring, parameters, self, prepared) + let statement = StatementBuilder::new(&querystring, ¶meters, self, prepared) .build() .await?; @@ -408,7 +395,7 @@ impl PSQLPyConnection { for vec_of_py_any in parameters { // TODO: Fix multiple qs creation let statement = - StatementBuilder::new(querystring.clone(), Some(vec_of_py_any), self, prepared) + StatementBuilder::new(&querystring, &Some(vec_of_py_any), self, prepared) .build() .await?; @@ -451,7 +438,7 @@ impl PSQLPyConnection { parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let statement = StatementBuilder::new(querystring, parameters, self, prepared) + let statement = StatementBuilder::new(&querystring, ¶meters, self, prepared) .build() .await?; @@ -547,12 +534,26 @@ impl PSQLPyConnection { pub async fn portal( &mut self, - querystring: String, - parameters: Option>, + querystring: Option<&String>, + parameters: &Option>, + statement: Option<&PsqlpyStatement>, ) -> PSQLPyResult<(PSQLPyTransaction, tp_Portal)> { - let statement = StatementBuilder::new(querystring, parameters, self, Some(false)) - .build() - .await?; + let statement = { + match statement { + Some(stmt) => stmt, + None => { + let Some(querystring) = querystring else { + return Err(RustPSQLDriverError::ConnectionExecuteError( + "Can't create cursor without querystring".into(), + )); + }; + + &StatementBuilder::new(querystring, parameters, self, Some(false)) + .build() + .await? + } + } + }; let transaction = self.transaction().await?; let inner_portal = transaction diff --git a/src/connection/structs.rs b/src/connection/structs.rs index ccfd101d..a50d3d69 100644 --- a/src/connection/structs.rs +++ b/src/connection/structs.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use deadpool_postgres::Object; use tokio_postgres::{Client, Config}; +#[derive(Debug)] pub struct PoolConnection { pub connection: Object, pub in_transaction: bool, @@ -20,6 +21,8 @@ impl PoolConnection { } } } + +#[derive(Debug)] pub struct SingleConnection { pub connection: Client, pub in_transaction: bool, @@ -38,6 +41,7 @@ impl SingleConnection { } } +#[derive(Debug)] pub enum PSQLPyConnection { PoolConn(PoolConnection), SingleConnection(SingleConnection), diff --git a/src/connection/traits.rs b/src/connection/traits.rs index 8e868a06..5d8d49ae 100644 --- a/src/connection/traits.rs +++ b/src/connection/traits.rs @@ -102,38 +102,38 @@ pub trait CloseTransaction: StartTransaction { fn rollback(&mut self) -> impl std::future::Future>; } -pub trait Cursor { - fn build_cursor_start_qs( - &self, - cursor_name: &str, - scroll: &Option, - querystring: &str, - ) -> String { - let mut cursor_init_query = format!("DECLARE {cursor_name}"); - if let Some(scroll) = scroll { - if *scroll { - cursor_init_query.push_str(" SCROLL"); - } else { - cursor_init_query.push_str(" NO SCROLL"); - } - } - - cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); - - cursor_init_query - } - - fn start_cursor( - &mut self, - cursor_name: &str, - scroll: &Option, - querystring: String, - prepared: &Option, - parameters: Option>, - ) -> impl std::future::Future>; - - fn close_cursor( - &mut self, - cursor_name: &str, - ) -> impl std::future::Future>; -} +// pub trait Cursor { +// fn build_cursor_start_qs( +// &self, +// cursor_name: &str, +// scroll: &Option, +// querystring: &str, +// ) -> String { +// let mut cursor_init_query = format!("DECLARE {cursor_name}"); +// if let Some(scroll) = scroll { +// if *scroll { +// cursor_init_query.push_str(" SCROLL"); +// } else { +// cursor_init_query.push_str(" NO SCROLL"); +// } +// } + +// cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); + +// cursor_init_query +// } + +// fn start_cursor( +// &mut self, +// cursor_name: &str, +// scroll: &Option, +// querystring: String, +// prepared: &Option, +// parameters: Option>, +// ) -> impl std::future::Future>; + +// fn close_cursor( +// &mut self, +// cursor_name: &str, +// ) -> impl std::future::Future>; +// } diff --git a/src/driver/common.rs b/src/driver/common.rs index 528fed84..ac447184 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -1,10 +1,23 @@ -use pyo3::prelude::*; use tokio_postgres::config::Host; use std::net::IpAddr; use super::{connection::Connection, cursor::Cursor, transaction::Transaction}; +use pyo3::{pymethods, Py, PyAny}; + +use crate::{ + connection::traits::CloseTransaction, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, +}; + +use bytes::BytesMut; +use futures_util::pin_mut; +use pyo3::{buffer::PyBuffer, PyErr, Python}; +use tokio_postgres::binary_copy::BinaryCopyInWriter; + +use crate::format_helpers::quote_ident; + macro_rules! impl_config_py_methods { ($name:ident) => { #[pymethods] @@ -92,3 +105,153 @@ macro_rules! impl_config_py_methods { impl_config_py_methods!(Transaction); impl_config_py_methods!(Connection); impl_config_py_methods!(Cursor); +// impl_config_py_methods!(Portal); + +macro_rules! impl_is_closed_method { + ($name:ident) => { + #[pymethods] + impl $name { + fn is_closed(&self) -> bool { + if self.conn.is_some() { + return true; + } + false + } + } + }; +} + +impl_is_closed_method!(Transaction); +impl_is_closed_method!(Connection); +impl_is_closed_method!(Cursor); + +macro_rules! impl_portal_method { + ($name:ident) => { + #[pymethods] + impl $name { + #[pyo3(signature = (querystring, parameters=None, fetch_number=None))] + pub fn cursor( + &self, + querystring: Option, + parameters: Option>, + fetch_number: Option, + ) -> PSQLPyResult { + Ok(Cursor::new( + self.conn.clone(), + querystring, + parameters, + fetch_number, + self.pg_config.clone(), + None, + )) + } + } + }; +} + +impl_portal_method!(Transaction); +impl_portal_method!(Connection); + +macro_rules! impl_transaction_methods { + ($name:ident, $val:expr $(,)?) => { + #[pymethods] + impl $name { + pub async fn commit(&mut self) -> PSQLPyResult<()> { + let conn = self.conn.clone(); + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError("1".into())); + }; + let mut write_conn_g = conn.write().await; + write_conn_g.commit().await?; + + if $val { + self.conn = None; + } + + Ok(()) + } + + pub async fn rollback(&mut self) -> PSQLPyResult<()> { + let conn = self.conn.clone(); + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError("2".into())); + }; + let mut write_conn_g = conn.write().await; + write_conn_g.rollback().await?; + + if $val { + self.conn = None; + } + + Ok(()) + } + } + }; +} + +impl_transaction_methods!(Connection, false); +impl_transaction_methods!(Transaction, true); + +macro_rules! impl_binary_copy_method { + ($name:ident) => { + #[pymethods] + impl $name { + #[pyo3(signature = (source, table_name, columns=None, schema_name=None))] + pub async fn binary_copy_to_table( + self_: pyo3::Py, + source: Py, + table_name: String, + columns: Option>, + schema_name: Option, + ) -> PSQLPyResult { + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); + let mut table_name = quote_ident(&table_name); + if let Some(schema_name) = schema_name { + table_name = format!("{}.{}", quote_ident(&schema_name), table_name); + } + + let mut formated_columns = String::default(); + if let Some(columns) = columns { + formated_columns = format!("({})", columns.join(", ")); + } + + let copy_qs = + format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); + + if let Some(db_client) = db_client { + let mut psql_bytes: BytesMut = Python::with_gil(|gil| { + let possible_py_buffer: Result, PyErr> = + source.extract::>(gil); + if let Ok(py_buffer) = possible_py_buffer { + let vec_buf = py_buffer.to_vec(gil)?; + return Ok(BytesMut::from(vec_buf.as_slice())); + } + + if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { + if let Ok(bytes) = py_bytes.extract::>(gil) { + return Ok(BytesMut::from(bytes.as_slice())); + } + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "source must be bytes or support Buffer protocol".into(), + )) + })?; + + let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(©_qs).await?; + let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); + pin_mut!(writer); + writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; + let rows_created = writer.as_mut().finish_empty().await?; + return Ok(rows_created); + } + + Ok(0) + } + } + }; +} + +impl_binary_copy_method!(Connection); +impl_binary_copy_method!(Transaction); diff --git a/src/driver/connection.rs b/src/driver/connection.rs index fa480386..3486141d 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,10 +1,8 @@ -use bytes::BytesMut; use deadpool_postgres::Pool; -use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, pyclass, pyfunction, pymethods, Py, PyAny, PyErr, Python}; +use pyo3::{ffi::PyObject, pyclass, pyfunction, pymethods, Py, PyAny, PyErr}; use std::sync::Arc; use tokio::sync::RwLock; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, Config}; +use tokio_postgres::Config; use crate::{ connection::{ @@ -12,14 +10,13 @@ use crate::{ traits::Connection as _, }, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, - format_helpers::quote_ident, options::{IsolationLevel, LoadBalanceHosts, ReadVariant, SslMode, TargetSessionAttrs}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, runtime::tokio_runtime, }; use super::{ - connection_pool::connect_pool, cursor::Cursor, portal::Portal, transaction::Transaction, + connection_pool::connect_pool, prepared_statement::PreparedStatement, transaction::Transaction, }; /// Make new connection pool. @@ -117,9 +114,9 @@ pub async fn connect( } #[pyclass(subclass)] -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Connection { - db_client: Option>>, + pub conn: Option>>, db_pool: Option, pub pg_config: Arc, } @@ -127,12 +124,12 @@ pub struct Connection { impl Connection { #[must_use] pub fn new( - db_client: Option>>, + conn: Option>>, db_pool: Option, pg_config: Arc, ) -> Self { Connection { - db_client, + conn, db_pool, pg_config, } @@ -140,7 +137,7 @@ impl Connection { #[must_use] pub fn db_client(&self) -> Option>> { - self.db_client.clone() + self.conn.clone() } #[must_use] @@ -157,11 +154,17 @@ impl Default for Connection { #[pymethods] impl Connection { + async fn in_transaction(&self) -> bool { + let Some(conn) = &self.conn else { return false }; + let read_conn_g = conn.read().await; + read_conn_g.in_transaction() + } + async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { let (db_client, db_pool, pg_config) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); ( - self_.db_client.clone(), + self_.conn.clone(), self_.db_pool.clone(), self_.pg_config.clone(), ) @@ -179,7 +182,7 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(RwLock::new(PSQLPyConnection::PoolConn( + self_.conn = Some(Arc::new(RwLock::new(PSQLPyConnection::PoolConn( PoolConnection::new(connection, pg_config), )))); }); @@ -206,7 +209,7 @@ impl Connection { pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - std::mem::take(&mut self_.db_client); + std::mem::take(&mut self_.conn); std::mem::take(&mut self_.db_pool); if is_exception_none { @@ -232,7 +235,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; @@ -256,7 +259,7 @@ impl Connection { /// 1) Connection is closed. /// 2) Cannot execute querystring. pub async fn execute_batch(self_: pyo3::Py, querystring: String) -> PSQLPyResult<()> { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; @@ -282,14 +285,17 @@ impl Connection { querystring: String, parameters: Option>>, prepared: Option, - ) -> PSQLPyResult<()> { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + ) -> PSQLPyResult> { + let (db_client, py_none) = + pyo3::Python::with_gil(|gil| (self_.borrow(gil).conn.clone(), gil.None().into_any())); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; - return read_conn_g + read_conn_g .execute_many(querystring, parameters, prepared) - .await; + .await?; + + return Ok(py_none); } Err(RustPSQLDriverError::ConnectionClosedError) @@ -310,7 +316,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; @@ -341,7 +347,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; @@ -371,7 +377,7 @@ impl Connection { parameters: Option>, prepared: Option, ) -> PSQLPyResult> { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); if let Some(db_client) = db_client { let read_conn_g = db_client.read().await; @@ -398,7 +404,7 @@ impl Connection { read_variant: Option, deferrable: Option, ) -> PSQLPyResult { - let Some(conn) = &self.db_client else { + let Some(conn) = &self.conn else { return Err(RustPSQLDriverError::ConnectionClosedError); }; Ok(Transaction::new( @@ -410,131 +416,38 @@ impl Connection { )) } - /// Create new cursor object. - /// - /// # Errors - /// May return Err Result if db_client is None. #[pyo3(signature = ( querystring, parameters=None, - fetch_number=None, - scroll=None, - prepared=None, ))] - pub fn cursor( + pub async fn prepare( &self, querystring: String, - parameters: Option>, - fetch_number: Option, - scroll: Option, - prepared: Option, - ) -> PSQLPyResult { - let Some(conn) = &self.db_client else { + parameters: Option>, + ) -> PSQLPyResult { + let Some(conn) = &self.conn else { return Err(RustPSQLDriverError::ConnectionClosedError); }; - Ok(Cursor::new( - conn.clone(), - self.pg_config.clone(), - querystring, - parameters, - fetch_number.unwrap_or(10), - scroll, - prepared, - )) - } + let read_conn_g = conn.read().await; + let prep_stmt = read_conn_g + .prepare_statement(querystring, parameters) + .await?; - #[pyo3(signature = ( - querystring, - parameters=None, - fetch_number=None, - ))] - pub fn portal( - &self, - querystring: String, - parameters: Option>, - fetch_number: Option, - ) -> PSQLPyResult { - println!("{:?}", fetch_number); - Ok(Portal::new( - self.db_client.clone(), - querystring, - parameters, - fetch_number, + Ok(PreparedStatement::new( + self.conn.clone(), + self.pg_config.clone(), + prep_stmt, )) } #[allow(clippy::needless_pass_by_value)] - pub fn back_to_pool(self_: pyo3::Py) { + pub fn close(self_: pyo3::Py) { pyo3::Python::with_gil(|gil| { let mut connection = self_.borrow_mut(gil); - if connection.db_client.is_some() { - std::mem::take(&mut connection.db_client); + if connection.conn.is_some() { + std::mem::take(&mut connection.conn); } }); } - - /// Perform binary copy to postgres table. - /// - /// # Errors - /// May return Err Result if cannot get bytes, - /// cannot perform request to the database, - /// cannot write bytes to the database. - #[pyo3(signature = ( - source, - table_name, - columns=None, - schema_name=None, - ))] - pub async fn binary_copy_to_table( - self_: pyo3::Py, - source: Py, - table_name: String, - columns: Option>, - schema_name: Option, - ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); - let mut table_name = quote_ident(&table_name); - if let Some(schema_name) = schema_name { - table_name = format!("{}.{}", quote_ident(&schema_name), table_name); - } - - let mut formated_columns = String::default(); - if let Some(columns) = columns { - formated_columns = format!("({})", columns.join(", ")); - } - - let copy_qs = format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); - - if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = Python::with_gil(|gil| { - let possible_py_buffer: Result, PyErr> = - source.extract::>(gil); - if let Ok(py_buffer) = possible_py_buffer { - let vec_buf = py_buffer.to_vec(gil)?; - return Ok(BytesMut::from(vec_buf.as_slice())); - } - - if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { - if let Ok(bytes) = py_bytes.extract::>(gil) { - return Ok(BytesMut::from(bytes.as_slice())); - } - } - - Err(RustPSQLDriverError::PyToRustValueConversionError( - "source must be bytes or support Buffer protocol".into(), - )) - })?; - - let read_conn_g = db_client.read().await; - let sink = read_conn_g.copy_in(©_qs).await?; - let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); - pin_mut!(writer); - writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; - let rows_created = writer.as_mut().finish_empty().await?; - return Ok(rows_created); - } - - Ok(0) - } } diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index a12d2bfa..ab40526a 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -1,103 +1,88 @@ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; +use std::sync::Arc; use pyo3::{ - exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, + exceptions::PyStopAsyncIteration, pyclass, pymethods, types::PyNone, Py, PyAny, PyErr, + PyObject, Python, }; use tokio::sync::RwLock; -use tokio_postgres::Config; +use tokio_postgres::{Config, Portal as tp_Portal}; use crate::{ - connection::{structs::PSQLPyConnection, traits::Cursor as _}, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::PSQLDriverPyQueryResult, runtime::rustdriver_future, + statement::statement::PsqlpyStatement, + transaction::structs::PSQLPyTransaction, }; -static NEXT_CUR_ID: AtomicUsize = AtomicUsize::new(0); - -fn next_cursor_name() -> String { - format!("cur{}", NEXT_CUR_ID.fetch_add(1, Ordering::SeqCst),) -} +use crate::connection::structs::PSQLPyConnection; -#[pyclass(subclass)] +#[pyclass] pub struct Cursor { - conn: Option>>, - pub pg_config: Arc, - querystring: String, + pub conn: Option>>, + querystring: Option, parameters: Option>, - cursor_name: Option, - fetch_number: usize, - scroll: Option, - prepared: Option, + array_size: i32, + + statement: Option, + + transaction: Option>, + inner: Option, + + pub pg_config: Arc, } impl Cursor { pub fn new( - conn: Arc>, - pg_config: Arc, - querystring: String, + conn: Option>>, + querystring: Option, parameters: Option>, - fetch_number: usize, - scroll: Option, - prepared: Option, + array_size: Option, + pg_config: Arc, + statement: Option, ) -> Self { - Cursor { - conn: Some(conn), - pg_config, + Self { + conn, + transaction: None, + inner: None, querystring, parameters, - cursor_name: None, - fetch_number, - scroll, - prepared, + array_size: array_size.unwrap_or(1), + pg_config, + statement, } } - async fn execute(&self, querystring: &str) -> PSQLPyResult { - let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::CursorClosedError); + async fn query_portal(&self, size: i32) -> PSQLPyResult { + let Some(transaction) = &self.transaction else { + return Err(RustPSQLDriverError::TransactionClosedError("3".into())); + }; + let Some(portal) = &self.inner else { + return Err(RustPSQLDriverError::TransactionClosedError("4".into())); }; - let read_conn_g = conn.read().await; - - let result = read_conn_g - .execute(querystring.to_string(), None, Some(false)) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - - Ok(result) + transaction.query_portal(&portal, size).await } } -#[pymethods] -impl Cursor { - #[getter] - fn cursor_name(&self) -> Option { - return self.cursor_name.clone(); - } - - #[getter] - fn querystring(&self) -> String { - return self.querystring.clone(); +impl Drop for Cursor { + fn drop(&mut self) { + self.transaction = None; + self.conn = None; } +} +#[pymethods] +impl Cursor { #[getter] - fn parameters(&self) -> Option> { - return self.parameters.clone(); + fn get_array_size(&self) -> i32 { + self.array_size } - #[getter] - fn prepared(&self) -> Option { - return self.prepared.clone(); + #[setter] + fn set_array_size(&mut self, value: i32) { + self.array_size = value; } - #[must_use] fn __aiter__(slf: Py) -> Py { slf } @@ -107,17 +92,13 @@ impl Cursor { } async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { - let cursor_name = next_cursor_name(); - - let (conn, scroll, querystring, prepared, parameters) = Python::with_gil(|gil| { - let mut self_ = slf.borrow_mut(gil); - self_.cursor_name = Some(cursor_name.clone()); + let (conn, querystring, parameters, statement) = Python::with_gil(|gil| { + let self_ = slf.borrow(gil); ( self_.conn.clone(), - self_.scroll, self_.querystring.clone(), - self_.prepared, self_.parameters.clone(), + self_.statement.clone(), ) }); @@ -126,15 +107,28 @@ impl Cursor { }; let mut write_conn_g = conn.write().await; - write_conn_g - .start_cursor( - &cursor_name, - &scroll, - querystring.clone(), - &prepared, - parameters.clone(), - ) - .await?; + let (txid, inner_portal) = match querystring { + Some(querystring) => { + write_conn_g + .portal(Some(&querystring), ¶meters, None) + .await? + } + None => { + let Some(statement) = statement else { + return Err(RustPSQLDriverError::CursorStartError( + "Cannot start cursor".into(), + )); + }; + write_conn_g.portal(None, &None, Some(&statement)).await? + } + }; + + Python::with_gil(|gil| { + let mut self_ = slf.borrow_mut(gil); + + self_.transaction = Some(Arc::new(txid)); + self_.inner = Some(inner_portal); + }); Ok(slf) } @@ -146,7 +140,7 @@ impl Cursor { exception: Py, _traceback: Py, ) -> PSQLPyResult<()> { - self.close().await?; + self.close(); let (is_exc_none, py_err) = pyo3::Python::with_gil(|gil| { ( @@ -162,33 +156,27 @@ impl Cursor { } fn __anext__(&self) -> PSQLPyResult> { - let conn = self.conn.clone(); - let fetch_number = self.fetch_number; - let Some(cursor_name) = self.cursor_name.clone() else { - return Err(RustPSQLDriverError::CursorClosedError); - }; + let txid = self.transaction.clone(); + let portal = self.inner.clone(); + let size = self.array_size.clone(); let py_future = Python::with_gil(move |gil| { rustdriver_future(gil, async move { - let Some(conn) = conn else { - return Err(RustPSQLDriverError::CursorClosedError); + let Some(txid) = &txid else { + return Err(RustPSQLDriverError::TransactionClosedError("5".into())); }; - - let read_conn_g = conn.read().await; - let result = read_conn_g - .execute( - format!("FETCH {fetch_number} FROM {cursor_name}"), - None, - Some(false), - ) - .await?; + let Some(portal) = &portal else { + return Err(RustPSQLDriverError::TransactionClosedError("6".into())); + }; + let result = txid.query_portal(&portal, size).await?; if result.is_empty() { return Err(PyStopAsyncIteration::new_err( - "Iteration is over, no more results in cursor", + "Iteration is over, no more results in portal", ) .into()); }; + Ok(result) }) }); @@ -196,148 +184,66 @@ impl Cursor { Ok(Some(py_future?)) } - pub async fn start(&mut self) -> PSQLPyResult<()> { - if self.cursor_name.is_some() { - return Ok(()); - } - + async fn start(&mut self) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::CursorClosedError); + return Err(RustPSQLDriverError::ConnectionClosedError); }; let mut write_conn_g = conn.write().await; - let cursor_name = next_cursor_name(); - - write_conn_g - .start_cursor( - &cursor_name, - &self.scroll, - self.querystring.clone(), - &self.prepared, - self.parameters.clone(), - ) - .await?; - - self.cursor_name = Some(cursor_name); - - Ok(()) - } - - pub async fn close(&mut self) -> PSQLPyResult<()> { - if let Some(cursor_name) = &self.cursor_name { - let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - let mut write_conn_g = conn.write().await; - write_conn_g.close_cursor(&cursor_name).await?; - self.cursor_name = None; + let (txid, inner_portal) = match &self.querystring { + Some(querystring) => { + write_conn_g + .portal(Some(&querystring), &self.parameters, None) + .await? + } + None => { + let Some(statement) = &self.statement else { + return Err(RustPSQLDriverError::CursorStartError( + "Cannot start cursor".into(), + )); + }; + write_conn_g.portal(None, &None, Some(&statement)).await? + } }; - self.conn = None; + self.transaction = Some(Arc::new(txid)); + self.inner = Some(inner_portal); Ok(()) } - #[pyo3(signature = (fetch_number=None))] - pub async fn fetch( - &self, - fetch_number: Option, - ) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!( - "FETCH {} FROM {}", - fetch_number.unwrap_or(self.fetch_number), - cursor_name, - )) - .await - } - - pub async fn fetch_next(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH NEXT FROM {cursor_name}")) - .await - } - - pub async fn fetch_prior(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH PRIOR FROM {cursor_name}")) - .await - } - - pub async fn fetch_first(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH FIRST FROM {cursor_name}")) - .await + fn close(&mut self) { + self.transaction = None; + self.conn = None; } - pub async fn fetch_last(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH LAST FROM {cursor_name}")) - .await - } + #[pyo3(signature = ( + querystring, + parameters=None, + ))] + async fn execute( + &mut self, + querystring: String, + parameters: Option>, + ) -> PSQLPyResult<()> { + self.querystring = Some(querystring); + self.parameters = parameters; - pub async fn fetch_absolute( - &self, - absolute_number: i64, - ) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!( - "FETCH ABSOLUTE {absolute_number} FROM {cursor_name}" - )) - .await - } + self.start().await?; - pub async fn fetch_relative( - &self, - relative_number: i64, - ) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!( - "FETCH RELATIVE {relative_number} FROM {cursor_name}" - )) - .await + Ok(()) } - pub async fn fetch_forward_all(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH FORWARD ALL FROM {cursor_name}")) - .await + async fn fetchone(&self) -> PSQLPyResult { + self.query_portal(1).await } - pub async fn fetch_backward( - &self, - backward_count: i64, - ) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!( - "FETCH BACKWARD {backward_count} FROM {cursor_name}" - )) - .await + #[pyo3(signature = (size=None))] + async fn fetchmany(&self, size: Option) -> PSQLPyResult { + self.query_portal(size.unwrap_or(self.array_size)).await } - pub async fn fetch_backward_all(&self) -> PSQLPyResult { - let Some(cursor_name) = &self.cursor_name else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - self.execute(&format!("FETCH BACKWARD ALL FROM {cursor_name}")) - .await + async fn fetchall(&self) -> PSQLPyResult { + self.query_portal(-1).await } } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index ab1b149c..30fec7c7 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -4,6 +4,6 @@ pub mod connection_pool; pub mod connection_pool_builder; pub mod cursor; pub mod listener; -pub mod portal; +pub mod prepared_statement; pub mod transaction; pub mod utils; diff --git a/src/driver/portal.rs b/src/driver/portal.rs deleted file mode 100644 index f6b4d755..00000000 --- a/src/driver/portal.rs +++ /dev/null @@ -1,195 +0,0 @@ -use std::sync::Arc; - -use pyo3::{ - exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, -}; -use tokio::sync::RwLock; -use tokio_postgres::Portal as tp_Portal; - -use crate::{ - exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, - query_result::PSQLDriverPyQueryResult, - runtime::rustdriver_future, - transaction::structs::PSQLPyTransaction, -}; - -use crate::connection::structs::PSQLPyConnection; - -#[pyclass] -pub struct Portal { - conn: Option>>, - querystring: String, - parameters: Option>, - array_size: i32, - - transaction: Option>, - inner: Option, -} - -impl Portal { - pub fn new( - conn: Option>>, - querystring: String, - parameters: Option>, - array_size: Option, - ) -> Self { - Self { - conn, - transaction: None, - inner: None, - querystring, - parameters, - array_size: array_size.unwrap_or(1), - } - } - - async fn query_portal(&self, size: i32) -> PSQLPyResult { - let Some(transaction) = &self.transaction else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - let Some(portal) = &self.inner else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - transaction.query_portal(&portal, size).await - } -} - -impl Drop for Portal { - fn drop(&mut self) { - self.transaction = None; - self.conn = None; - } -} - -#[pymethods] -impl Portal { - #[getter] - fn get_array_size(&self) -> i32 { - self.array_size - } - - #[setter] - fn set_array_size(&mut self, value: i32) { - self.array_size = value; - } - - fn __aiter__(slf: Py) -> Py { - slf - } - - fn __await__(slf: Py) -> Py { - slf - } - - async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { - let (conn, querystring, parameters) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - ( - self_.conn.clone(), - self_.querystring.clone(), - self_.parameters.clone(), - ) - }); - - let Some(conn) = conn else { - return Err(RustPSQLDriverError::CursorClosedError); - }; - let mut write_conn_g = conn.write().await; - - let (txid, inner_portal) = write_conn_g.portal(querystring, parameters).await?; - - Python::with_gil(|gil| { - let mut self_ = slf.borrow_mut(gil); - - self_.transaction = Some(Arc::new(txid)); - self_.inner = Some(inner_portal); - }); - - Ok(slf) - } - - #[allow(clippy::needless_pass_by_value)] - async fn __aexit__<'a>( - &mut self, - _exception_type: Py, - exception: Py, - _traceback: Py, - ) -> PSQLPyResult<()> { - self.close(); - - let (is_exc_none, py_err) = pyo3::Python::with_gil(|gil| { - ( - exception.is_none(gil), - PyErr::from_value(exception.into_bound(gil)), - ) - }); - - if !is_exc_none { - return Err(RustPSQLDriverError::RustPyError(py_err)); - } - Ok(()) - } - - fn __anext__(&self) -> PSQLPyResult> { - let txid = self.transaction.clone(); - let portal = self.inner.clone(); - let size = self.array_size.clone(); - - let py_future = Python::with_gil(move |gil| { - rustdriver_future(gil, async move { - let Some(txid) = &txid else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - let Some(portal) = &portal else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - let result = txid.query_portal(&portal, size).await?; - - if result.is_empty() { - return Err(PyStopAsyncIteration::new_err( - "Iteration is over, no more results in portal", - ) - .into()); - }; - - Ok(result) - }) - }); - - Ok(Some(py_future?)) - } - - async fn start(&mut self) -> PSQLPyResult<()> { - let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::ConnectionClosedError); - }; - let mut write_conn_g = conn.write().await; - - let (txid, inner_portal) = write_conn_g - .portal(self.querystring.clone(), self.parameters.clone()) - .await?; - - self.transaction = Some(Arc::new(txid)); - self.inner = Some(inner_portal); - - Ok(()) - } - - async fn fetch_one(&self) -> PSQLPyResult { - self.query_portal(1).await - } - - #[pyo3(signature = (size=None))] - async fn fetch_many(&self, size: Option) -> PSQLPyResult { - self.query_portal(size.unwrap_or(self.array_size)).await - } - - async fn fetch_all(&self) -> PSQLPyResult { - self.query_portal(-1).await - } - - fn close(&mut self) { - self.transaction = None; - self.conn = None; - } -} diff --git a/src/driver/prepared_statement.rs b/src/driver/prepared_statement.rs new file mode 100644 index 00000000..1880449c --- /dev/null +++ b/src/driver/prepared_statement.rs @@ -0,0 +1,63 @@ +use std::sync::Arc; + +use pyo3::{pyclass, pymethods}; +use tokio::sync::RwLock; +use tokio_postgres::Config; + +use crate::{ + connection::structs::PSQLPyConnection, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + query_result::PSQLDriverPyQueryResult, + statement::{parameters::Column, statement::PsqlpyStatement}, +}; + +use super::cursor::Cursor; + +#[pyclass(subclass)] +#[derive(Debug)] +pub struct PreparedStatement { + pub conn: Option>>, + pub pg_config: Arc, + statement: PsqlpyStatement, +} + +impl PreparedStatement { + pub fn new( + conn: Option>>, + pg_config: Arc, + statement: PsqlpyStatement, + ) -> Self { + Self { + conn, + pg_config, + statement, + } + } +} + +#[pymethods] +impl PreparedStatement { + async fn execute(&self) -> PSQLPyResult { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError("12".into())); + }; + + let read_conn_g = conn.read().await; + read_conn_g.execute_statement(&self.statement).await + } + + fn cursor(&self) -> PSQLPyResult { + Ok(Cursor::new( + self.conn.clone(), + None, + None, + None, + self.pg_config.clone(), + Some(self.statement.clone()), + )) + } + + fn columns(&self) -> Vec { + self.statement.columns().clone() + } +} diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index c779837f..81845d40 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use bytes::BytesMut; -use futures::{future, pin_mut}; +use futures::future; use pyo3::{ - buffer::PyBuffer, pyclass, pymethods, types::{PyAnyMethods, PyList, PyTuple}, Py, PyAny, PyErr, PyResult, }; use tokio::sync::RwLock; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, Config}; +use tokio_postgres::Config; use crate::{ connection::{ @@ -17,14 +15,12 @@ use crate::{ traits::{CloseTransaction, Connection, StartTransaction as _}, }, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, - format_helpers::quote_ident, options::{IsolationLevel, ReadVariant}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, }; -use super::cursor::Cursor; - #[pyclass(subclass)] +#[derive(Debug)] pub struct Transaction { pub conn: Option>>, pub pg_config: Arc, @@ -75,7 +71,7 @@ impl Transaction { }); let Some(conn) = conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("7".into())); }; let mut write_conn_g = conn.write().await; write_conn_g @@ -102,7 +98,7 @@ impl Transaction { }); let Some(conn) = conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("8".into())); }; let mut write_conn_g = conn.write().await; if is_exception_none { @@ -125,7 +121,7 @@ impl Transaction { pub async fn begin(&mut self) -> PSQLPyResult<()> { let conn = &self.conn; let Some(conn) = conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("9".into())); }; let mut write_conn_g = conn.write().await; write_conn_g @@ -135,32 +131,6 @@ impl Transaction { Ok(()) } - pub async fn commit(&mut self) -> PSQLPyResult<()> { - let conn = self.conn.clone(); - let Some(conn) = conn else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - let mut write_conn_g = conn.write().await; - write_conn_g.commit().await?; - - self.conn = None; - - Ok(()) - } - - pub async fn rollback(&mut self) -> PSQLPyResult<()> { - let conn = self.conn.clone(); - let Some(conn) = conn else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - let mut write_conn_g = conn.write().await; - write_conn_g.rollback().await?; - - self.conn = None; - - Ok(()) - } - #[pyo3(signature = (querystring, parameters=None, prepared=None))] pub async fn execute( &self, @@ -169,7 +139,7 @@ impl Transaction { prepared: Option, ) -> PSQLPyResult { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("10".into())); }; let read_conn_g = conn.read().await; @@ -184,7 +154,7 @@ impl Transaction { prepared: Option, ) -> PSQLPyResult { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("11".into())); }; let read_conn_g = conn.read().await; @@ -199,7 +169,7 @@ impl Transaction { prepared: Option, ) -> PSQLPyResult> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("12".into())); }; let read_conn_g = conn.read().await; @@ -210,7 +180,7 @@ impl Transaction { pub async fn execute_batch(&self, querystring: String) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("13".into())); }; let read_conn_g = conn.read().await; @@ -225,7 +195,7 @@ impl Transaction { prepared: Option, ) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("14".into())); }; let read_conn_g = conn.read().await; @@ -242,7 +212,7 @@ impl Transaction { prepared: Option, ) -> PSQLPyResult { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("15".into())); }; let read_conn_g = conn.read().await; @@ -253,7 +223,7 @@ impl Transaction { pub async fn create_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("16".into())); }; let read_conn_g = conn.read().await; @@ -266,7 +236,7 @@ impl Transaction { pub async fn release_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("17".into())); }; let read_conn_g = conn.read().await; @@ -279,7 +249,7 @@ impl Transaction { pub async fn rollback_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); + return Err(RustPSQLDriverError::TransactionClosedError("18".into())); }; let read_conn_g = conn.read().await; @@ -290,39 +260,6 @@ impl Transaction { Ok(()) } - /// Create new cursor object. - /// - /// # Errors - /// May return Err Result if db_client is None - #[pyo3(signature = ( - querystring, - parameters=None, - fetch_number=None, - scroll=None, - prepared=None, - ))] - pub fn cursor( - &self, - querystring: String, - parameters: Option>, - fetch_number: Option, - scroll: Option, - prepared: Option, - ) -> PSQLPyResult { - let Some(conn) = &self.conn else { - return Err(RustPSQLDriverError::TransactionClosedError); - }; - Ok(Cursor::new( - conn.clone(), - self.pg_config.clone(), - querystring, - parameters, - fetch_number.unwrap_or(10), - scroll, - prepared, - )) - } - #[pyo3(signature = (queries=None, prepared=None))] pub async fn pipeline<'py>( self_: Py, @@ -368,65 +305,6 @@ impl Transaction { return future::try_join_all(futures).await; } - Err(RustPSQLDriverError::TransactionClosedError) - } - - /// Perform binary copy to postgres table. - /// - /// # Errors - /// May return Err Result if cannot get bytes, - /// cannot perform request to the database, - /// cannot write bytes to the database. - #[pyo3(signature = (source, table_name, columns=None, schema_name=None))] - pub async fn binary_copy_to_table( - self_: pyo3::Py, - source: Py, - table_name: String, - columns: Option>, - schema_name: Option, - ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); - let mut table_name = quote_ident(&table_name); - if let Some(schema_name) = schema_name { - table_name = format!("{}.{}", quote_ident(&schema_name), table_name); - } - - let mut formated_columns = String::default(); - if let Some(columns) = columns { - formated_columns = format!("({})", columns.join(", ")); - } - - let copy_qs = format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); - - if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = pyo3::Python::with_gil(|gil| { - let possible_py_buffer: Result, PyErr> = - source.extract::>(gil); - if let Ok(py_buffer) = possible_py_buffer { - let vec_buf = py_buffer.to_vec(gil)?; - return Ok(BytesMut::from(vec_buf.as_slice())); - } - - if let Ok(py_bytes) = source.call_method0(gil, "getvalue") { - if let Ok(bytes) = py_bytes.extract::>(gil) { - return Ok(BytesMut::from(bytes.as_slice())); - } - } - - Err(RustPSQLDriverError::PyToRustValueConversionError( - "source must be bytes or support Buffer protocol".into(), - )) - })?; - - let read_conn_g = db_client.read().await; - let sink = read_conn_g.copy_in(©_qs).await?; - let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); - pin_mut!(writer); - writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; - let rows_created = writer.as_mut().finish_empty().await?; - return Ok(rows_created); - } - - Ok(0) + Err(RustPSQLDriverError::TransactionClosedError("19".into())) } } diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index f133321b..9062a37e 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -49,8 +49,8 @@ pub enum RustPSQLDriverError { TransactionSavepointError(String), #[error("Transaction execute error: {0}")] TransactionExecuteError(String), - #[error("Underlying connection is returned to the pool")] - TransactionClosedError, + #[error("Underlying connection is returned to the pool: {0}")] + TransactionClosedError(String), // Cursor Errors #[error("Cursor error: {0}")] @@ -162,7 +162,7 @@ impl From for pyo3::PyErr { RustPSQLDriverError::TransactionExecuteError(_) => { TransactionExecuteError::new_err((error_desc,)) } - RustPSQLDriverError::TransactionClosedError => { + RustPSQLDriverError::TransactionClosedError(_) => { TransactionClosedError::new_err((error_desc,)) } RustPSQLDriverError::BaseCursorError(_) => BaseCursorError::new_err((error_desc,)), diff --git a/src/lib.rs b/src/lib.rs index 3229e675..a20c1ce4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,8 +35,10 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_function(wrap_pyfunction!(driver::connection::connect, pymod)?)?; pymod.add_class::()?; + // pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; - pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; diff --git a/src/options.rs b/src/options.rs index bd8ad511..f6e4152f 100644 --- a/src/options.rs +++ b/src/options.rs @@ -141,7 +141,7 @@ impl CopyCommandFormat { } #[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, PartialEq, Debug)] pub enum IsolationLevel { ReadUncommitted, ReadCommitted, @@ -163,7 +163,7 @@ impl IsolationLevel { } #[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, PartialEq, Debug)] pub enum ReadVariant { ReadOnly, ReadWrite, diff --git a/src/query_result.rs b/src/query_result.rs index cda02a8b..d9dd8848 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,4 +1,9 @@ -use pyo3::{prelude::*, pyclass, pymethods, types::PyDict, Py, PyAny, Python, ToPyObject}; +use pyo3::{ + prelude::*, + pyclass, pymethods, + types::{PyDict, PyTuple}, + Py, PyAny, Python, ToPyObject, +}; use tokio_postgres::Row; use crate::{exceptions::rust_errors::PSQLPyResult, value_converter::to_python::postgres_to_py}; diff --git a/src/statement/cache.rs b/src/statement/cache.rs index 7d78898d..7c07da40 100644 --- a/src/statement/cache.rs +++ b/src/statement/cache.rs @@ -5,7 +5,7 @@ use postgres_types::Type; use tokio::sync::RwLock; use tokio_postgres::Statement; -use super::{query::QueryString, utils::hash_str}; +use super::{parameters::Column, query::QueryString, utils::hash_str}; #[derive(Default)] pub(crate) struct StatementsCache(HashMap); @@ -44,6 +44,14 @@ impl StatementCacheInfo { pub(crate) fn types(&self) -> Vec { self.inner_stmt.params().to_vec() } + + pub(crate) fn columns(&self) -> Vec { + self.inner_stmt + .columns() + .iter() + .map(|column| Column::new(column.name().to_string(), column.table_oid().clone())) + .collect::>() + } } pub(crate) static STMTS_CACHE: Lazy> = diff --git a/src/statement/parameters.rs b/src/statement/parameters.rs index 09e0cbef..3aa12160 100644 --- a/src/statement/parameters.rs +++ b/src/statement/parameters.rs @@ -3,6 +3,7 @@ use std::iter::zip; use postgres_types::{ToSql, Type}; use pyo3::{ conversion::FromPyObjectBound, + pyclass, pymethods, types::{PyAnyMethods, PyMapping}, Py, PyObject, PyTypeCheck, Python, }; @@ -17,16 +18,48 @@ use crate::{ pub type QueryParameter = (dyn ToSql + Sync); +#[pyclass] +#[derive(Default, Clone, Debug)] +pub struct Column { + name: String, + table_oid: Option, +} + +impl Column { + pub fn new(name: String, table_oid: Option) -> Self { + Self { name, table_oid } + } +} + +#[pymethods] +impl Column { + #[getter] + fn get_name(&self) -> String { + self.name.clone() + } + + #[getter] + fn get_table_oid(&self) -> Option { + self.table_oid.clone() + } +} + pub(crate) struct ParametersBuilder { parameters: Option, types: Option>, + columns: Vec, } impl ParametersBuilder { - pub fn new(parameters: &Option, types: Option>) -> Self { + pub fn new( + parameters: &Option, + types: Option>, + columns: Vec, + ) -> Self { Self { parameters: parameters.clone(), types, + columns, } } @@ -55,13 +88,15 @@ impl ParametersBuilder { match (sequence_typed, mapping_typed) { (Some(sequence), None) => { - prepared_parameters = - Some(SequenceParametersBuilder::new(sequence, self.types).prepare(gil)?); + prepared_parameters = Some( + SequenceParametersBuilder::new(sequence, self.types, self.columns) + .prepare(gil)?, + ); } (None, Some(mapping)) => { if let Some(parameters_names) = parameters_names { prepared_parameters = Some( - MappingParametersBuilder::new(mapping, self.types) + MappingParametersBuilder::new(mapping, self.types, self.columns) .prepare(gil, parameters_names)?, ) } @@ -110,13 +145,15 @@ impl ParametersBuilder { pub(crate) struct MappingParametersBuilder { map_parameters: Py, types: Option>, + columns: Vec, } impl MappingParametersBuilder { - fn new(map_parameters: Py, types: Option>) -> Self { + fn new(map_parameters: Py, types: Option>, columns: Vec) -> Self { Self { map_parameters, types, + columns, } } @@ -143,7 +180,11 @@ impl MappingParametersBuilder { .map(|(parameter, type_)| from_python_typed(parameter.bind(gil), &type_)) .collect::>>()?; - Ok(PreparedParameters::new(converted_parameters, types)) + Ok(PreparedParameters::new( + converted_parameters, + types, + self.columns, + )) } fn prepare_not_typed( @@ -157,7 +198,11 @@ impl MappingParametersBuilder { .map(|parameter| from_python_untyped(parameter.bind(gil))) .collect::>>()?; - Ok(PreparedParameters::new(converted_parameters, vec![])) + Ok(PreparedParameters::new( + converted_parameters, + vec![], + self.columns, + )) } fn extract_parameters( @@ -185,13 +230,15 @@ impl MappingParametersBuilder { pub(crate) struct SequenceParametersBuilder { seq_parameters: Vec, types: Option>, + columns: Vec, } impl SequenceParametersBuilder { - fn new(seq_parameters: Vec, types: Option>) -> Self { + fn new(seq_parameters: Vec, types: Option>, columns: Vec) -> Self { Self { seq_parameters: seq_parameters, types, + columns, } } @@ -208,7 +255,11 @@ impl SequenceParametersBuilder { .map(|(parameter, type_)| from_python_typed(parameter.bind(gil), &type_)) .collect::>>()?; - Ok(PreparedParameters::new(converted_parameters, types)) + Ok(PreparedParameters::new( + converted_parameters, + types, + self.columns, + )) } fn prepare_not_typed(self, gil: Python<'_>) -> PSQLPyResult { @@ -218,7 +269,11 @@ impl SequenceParametersBuilder { .map(|parameter| from_python_untyped(parameter.bind(gil))) .collect::>>()?; - Ok(PreparedParameters::new(converted_parameters, vec![])) + Ok(PreparedParameters::new( + converted_parameters, + vec![], + self.columns, + )) } } @@ -226,11 +281,16 @@ impl SequenceParametersBuilder { pub struct PreparedParameters { parameters: Vec, types: Vec, + columns: Vec, } impl PreparedParameters { - pub fn new(parameters: Vec, types: Vec) -> Self { - Self { parameters, types } + pub fn new(parameters: Vec, types: Vec, columns: Vec) -> Self { + Self { + parameters, + types, + columns, + } } pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { @@ -251,4 +311,8 @@ impl PreparedParameters { .collect::>() .into_boxed_slice() } + + pub fn columns(&self) -> &Vec { + &self.columns + } } diff --git a/src/statement/statement.rs b/src/statement/statement.rs index fd77eb55..fc45b3eb 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -3,7 +3,10 @@ use tokio_postgres::Statement; use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; -use super::{parameters::PreparedParameters, query::QueryString}; +use super::{ + parameters::{Column, PreparedParameters}, + query::QueryString, +}; #[derive(Clone, Debug)] pub struct PsqlpyStatement { @@ -47,4 +50,8 @@ impl PsqlpyStatement { pub fn params_typed(&self) -> Box<[(&(dyn ToSql + Sync), Type)]> { self.prepared_parameters.params_typed() } + + pub fn columns(&self) -> &Vec { + &self.prepared_parameters.columns() + } } diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs index c909f68d..054352a3 100644 --- a/src/statement/statement_builder.rs +++ b/src/statement/statement_builder.rs @@ -9,22 +9,22 @@ use crate::{ use super::{ cache::{StatementCacheInfo, StatementsCache, STMTS_CACHE}, - parameters::ParametersBuilder, + parameters::{Column, ParametersBuilder}, query::QueryString, statement::PsqlpyStatement, }; pub struct StatementBuilder<'a> { - querystring: String, - parameters: Option, + querystring: &'a String, + parameters: &'a Option, inner_conn: &'a PSQLPyConnection, prepared: bool, } impl<'a> StatementBuilder<'a> { pub fn new( - querystring: String, - parameters: Option, + querystring: &'a String, + parameters: &'a Option, inner_conn: &'a PSQLPyConnection, prepared: Option, ) -> Self { @@ -51,7 +51,8 @@ impl<'a> StatementBuilder<'a> { } fn build_with_cached(self, cached: StatementCacheInfo) -> PSQLPyResult { - let raw_parameters = ParametersBuilder::new(&self.parameters, Some(cached.types())); + let raw_parameters = + ParametersBuilder::new(&self.parameters, Some(cached.types()), cached.columns()); let parameters_names = if let Some(converted_qs) = &cached.query.converted_qs { Some(converted_qs.params_names().clone()) @@ -76,8 +77,17 @@ impl<'a> StatementBuilder<'a> { querystring.process_qs(); let prepared_stmt = self.prepare_query(&querystring, self.prepared).await?; - let parameters_builder = - ParametersBuilder::new(&self.parameters, Some(prepared_stmt.params().to_vec())); + + let columns = prepared_stmt + .columns() + .iter() + .map(|column| Column::new(column.name().to_string(), column.table_oid().clone())) + .collect::>(); + let parameters_builder = ParametersBuilder::new( + &self.parameters, + Some(prepared_stmt.params().to_vec()), + columns, + ); let parameters_names = if let Some(converted_qs) = &querystring.converted_qs { Some(converted_qs.params_names().clone()) diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index c0801bac..6ee761bc 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -172,6 +172,12 @@ fn postgres_bytes_to_py( } Ok(py.None()) } + Type::OID => { + Ok(composite_field_postgres_to_py::>(type_, buf, is_simple)?.to_object(py)) + } + Type::NAME => Ok( + composite_field_postgres_to_py::>(type_, buf, is_simple)?.to_object(py), + ), // // ---------- String Types ---------- // // Convert TEXT and VARCHAR type into String, then into str Type::TEXT | Type::VARCHAR | Type::XML => Ok(composite_field_postgres_to_py::< @@ -342,6 +348,11 @@ fn postgres_bytes_to_py( composite_field_postgres_to_py::>>(type_, buf, is_simple)?, ) .to_object(py)), + Type::OID_ARRAY => Ok(postgres_array_to_py( + py, + composite_field_postgres_to_py::>>(type_, buf, is_simple)?, + ) + .to_object(py)), // Convert ARRAY of TEXT or VARCHAR into Vec, then into list[str] Type::TEXT_ARRAY | Type::VARCHAR_ARRAY | Type::XML_ARRAY => Ok(postgres_array_to_py( py,