diff --git a/Cargo.lock b/Cargo.lock index fc0b9c9b..076ac4c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -572,6 +572,15 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "inventory" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b31349d02fe60f80bbbab1a9402364cad7460626d6030494b08ac4a2075bf81" +dependencies = [ + "rustversion", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1061,6 +1070,7 @@ dependencies = [ "cfg-if", "chrono", "indoc", + "inventory", "libc", "memoffset", "once_cell", @@ -1291,6 +1301,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustversion" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" + [[package]] name = "ryu" version = "1.0.18" diff --git a/Cargo.toml b/Cargo.toml index b33f08c2..68160a6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] deadpool-postgres = { git = "https://github.com/chandr-andr/deadpool.git", branch = "psqlpy" } -pyo3 = { version = "0.23.4", features = ["chrono", "experimental-async", "rust_decimal", "py-clone", "macros"] } +pyo3 = { version = "0.23.4", features = ["chrono", "experimental-async", "rust_decimal", "py-clone", "macros", "multiple-pymethods"] } pyo3-async-runtimes = { git = "https://github.com/chandr-andr/pyo3-async-runtimes.git", branch = "psqlpy", features = [ "tokio-runtime", ] } diff --git a/python/tests/test_transaction.py b/python/tests/test_transaction.py index a6dfd191..a186084a 100644 --- a/python/tests/test_transaction.py +++ b/python/tests/test_transaction.py @@ -9,11 +9,10 @@ IsolationLevel, ReadVariant, ) +from psqlpy._internal.exceptions import TransactionClosedError from psqlpy.exceptions import ( InterfaceError, - TransactionBeginError, TransactionExecuteError, - TransactionSavepointError, ) from tests.helpers import count_rows_in_test_table @@ -64,10 +63,10 @@ async def test_transaction_begin( connection = await psql_pool.connection() transaction = connection.transaction() - with pytest.raises(expected_exception=TransactionBeginError): - await transaction.execute( - f"SELECT * FROM {table_name}", - ) + # with pytest.raises(expected_exception=TransactionBeginError): + await transaction.execute( + f"SELECT * FROM {table_name}", + ) await transaction.begin() @@ -170,7 +169,7 @@ async def test_transaction_rollback( await transaction.rollback() - with pytest.raises(expected_exception=TransactionBeginError): + with pytest.raises(expected_exception=TransactionClosedError): await transaction.execute( f"SELECT * FROM {table_name} WHERE name = $1", parameters=[test_name], @@ -198,9 +197,8 @@ async def test_transaction_release_savepoint( sp_name_2 = "sp2" await transaction.create_savepoint(sp_name_1) - - with pytest.raises(expected_exception=TransactionSavepointError): - await transaction.create_savepoint(sp_name_1) + # There is no problem in creating the same sp_name + await transaction.create_savepoint(sp_name_1) await transaction.create_savepoint(sp_name_2) diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 50b195a0..ee6bab4b 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -12,15 +12,15 @@ use crate::{ use super::{ structs::{PSQLPyConnection, PoolConnection, SingleConnection}, - traits::{Connection, Transaction}, + traits::{CloseTransaction, Connection, Cursor, StartTransaction, Transaction}, }; impl Transaction for T where T: Connection, { - async fn start( - &self, + async fn _start_transaction( + &mut self, isolation_level: Option, read_variant: Option, deferrable: Option, @@ -35,7 +35,7 @@ where Ok(()) } - async fn commit(&self) -> PSQLPyResult<()> { + async fn _commit(&self) -> PSQLPyResult<()> { self.batch_execute("COMMIT;").await.map_err(|err| { RustPSQLDriverError::TransactionCommitError(format!( "Cannot execute COMMIT statement, error - {err}" @@ -44,7 +44,7 @@ where Ok(()) } - async fn rollback(&self) -> PSQLPyResult<()> { + async fn _rollback(&self) -> PSQLPyResult<()> { self.batch_execute("ROLLBACK;").await.map_err(|err| { RustPSQLDriverError::TransactionRollbackError(format!( "Cannot execute ROLLBACK statement, error - {err}" @@ -105,41 +105,37 @@ impl Connection for SingleConnection { } } -// impl Transaction for SingleConnection { -// async fn start( -// &self, -// isolation_level: Option, -// read_variant: Option, -// deferrable: Option, -// ) -> PSQLPyResult<()> { -// let start_qs = self.build_start_qs(isolation_level, read_variant, deferrable); -// self.batch_execute(start_qs.as_str()).await.map_err(|err| { -// RustPSQLDriverError::TransactionBeginError( -// format!("Cannot start transaction due to - {err}").into(), -// ) -// })?; - -// Ok(()) -// } - -// async fn commit(&self) -> PSQLPyResult<()> { -// self.batch_execute("COMMIT;").await.map_err(|err| { -// RustPSQLDriverError::TransactionCommitError(format!( -// "Cannot execute COMMIT statement, error - {err}" -// )) -// })?; -// Ok(()) -// } - -// async fn rollback(&self) -> PSQLPyResult<()> { -// self.batch_execute("ROLLBACK;").await.map_err(|err| { -// RustPSQLDriverError::TransactionRollbackError(format!( -// "Cannot execute ROLLBACK statement, error - {err}" -// )) -// })?; -// Ok(()) -// } -// } +impl StartTransaction for SingleConnection { + async fn start_transaction( + &mut self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> PSQLPyResult<()> { + let res = self + ._start_transaction(isolation_level, read_variant, deferrable) + .await?; + self.in_transaction = true; + + Ok(res) + } +} + +impl CloseTransaction for SingleConnection { + async fn commit(&mut self) -> PSQLPyResult<()> { + let res = self._commit().await?; + self.in_transaction = false; + + Ok(res) + } + + async fn rollback(&mut self) -> PSQLPyResult<()> { + let res = self._rollback().await?; + self.in_transaction = false; + + Ok(res) + } +} impl Connection for PoolConnection { async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { @@ -193,41 +189,34 @@ impl Connection for PoolConnection { } } -// impl Transaction for PoolConnection { -// async fn start( -// &self, -// isolation_level: Option, -// read_variant: Option, -// deferrable: Option, -// ) -> PSQLPyResult<()> { -// let start_qs = self.build_start_qs(isolation_level, read_variant, deferrable); -// self.batch_execute(start_qs.as_str()).await.map_err(|err| { -// RustPSQLDriverError::TransactionBeginError( -// format!("Cannot start transaction due to - {err}").into(), -// ) -// })?; - -// Ok(()) -// } - -// async fn commit(&self) -> PSQLPyResult<()> { -// self.batch_execute("COMMIT;").await.map_err(|err| { -// RustPSQLDriverError::TransactionCommitError(format!( -// "Cannot execute COMMIT statement, error - {err}" -// )) -// })?; -// Ok(()) -// } - -// async fn rollback(&self) -> PSQLPyResult<()> { -// self.batch_execute("ROLLBACK;").await.map_err(|err| { -// RustPSQLDriverError::TransactionRollbackError(format!( -// "Cannot execute ROLLBACK statement, error - {err}" -// )) -// })?; -// Ok(()) -// } -// } +impl StartTransaction for PoolConnection { + async fn start_transaction( + &mut self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> PSQLPyResult<()> { + self.in_transaction = true; + self._start_transaction(isolation_level, read_variant, deferrable) + .await + } +} + +impl CloseTransaction for PoolConnection { + async fn commit(&mut self) -> PSQLPyResult<()> { + let res = self._commit().await?; + self.in_transaction = false; + + Ok(res) + } + + async fn rollback(&mut self) -> PSQLPyResult<()> { + let res = self._rollback().await?; + self.in_transaction = false; + + Ok(res) + } +} impl Connection for PSQLPyConnection { async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { @@ -293,37 +282,81 @@ impl Connection for PSQLPyConnection { } } -// impl Transaction for PSQLPyConnection { -// async fn start( -// &self, -// isolation_level: Option, -// read_variant: Option, -// deferrable: Option, -// ) -> PSQLPyResult<()> { -// match self { -// PSQLPyConnection::PoolConn(p_conn) => p_conn.start(isolation_level, read_variant, deferrable).await, -// PSQLPyConnection::SingleConnection(s_conn) => s_conn.start(isolation_level, read_variant, deferrable).await, -// } -// } - -// async fn commit(&self) -> PSQLPyResult<()> { -// self.batch_execute("COMMIT;").await.map_err(|err| { -// RustPSQLDriverError::TransactionCommitError(format!( -// "Cannot execute COMMIT statement, error - {err}" -// )) -// })?; -// Ok(()) -// } - -// async fn rollback(&self) -> PSQLPyResult<()> { -// self.batch_execute("ROLLBACK;").await.map_err(|err| { -// RustPSQLDriverError::TransactionRollbackError(format!( -// "Cannot execute ROLLBACK statement, error - {err}" -// )) -// })?; -// Ok(()) -// } -// } +impl StartTransaction for PSQLPyConnection { + async fn start_transaction( + &mut self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> PSQLPyResult<()> { + match self { + PSQLPyConnection::PoolConn(p_conn) => { + p_conn + .start_transaction(isolation_level, read_variant, deferrable) + .await + } + PSQLPyConnection::SingleConnection(s_conn) => { + s_conn + .start_transaction(isolation_level, read_variant, deferrable) + .await + } + } + } +} + +impl CloseTransaction for PSQLPyConnection { + async fn commit(&mut self) -> PSQLPyResult<()> { + match self { + PSQLPyConnection::PoolConn(p_conn) => p_conn.commit().await, + PSQLPyConnection::SingleConnection(s_conn) => s_conn.commit().await, + } + } + + async fn rollback(&mut self) -> PSQLPyResult<()> { + match self { + PSQLPyConnection::PoolConn(p_conn) => p_conn.rollback().await, + PSQLPyConnection::SingleConnection(s_conn) => s_conn.rollback().await, + } + } +} + +impl Cursor for PSQLPyConnection { + async fn start_cursor( + &mut self, + cursor_name: &str, + scroll: &Option, + querystring: String, + prepared: &Option, + parameters: Option>, + ) -> PSQLPyResult<()> { + let cursor_qs = self.build_cursor_start_qs(cursor_name, scroll, &querystring); + self.execute(cursor_qs, parameters, *prepared) + .await + .map_err(|err| { + RustPSQLDriverError::CursorStartError(format!("Cannot start cursor due to {err}")) + })?; + match self { + PSQLPyConnection::PoolConn(conn) => conn.in_cursor = true, + PSQLPyConnection::SingleConnection(conn) => conn.in_cursor = true, + } + Ok(()) + } + + async fn close_cursor(&mut self, cursor_name: &str) -> PSQLPyResult<()> { + self.execute( + format!("CLOSE {cursor_name}"), + Option::default(), + Some(false), + ) + .await?; + + match self { + PSQLPyConnection::PoolConn(conn) => conn.in_cursor = false, + PSQLPyConnection::SingleConnection(conn) => conn.in_cursor = false, + } + Ok(()) + } +} impl PSQLPyConnection { pub async fn execute( @@ -337,23 +370,24 @@ impl PSQLPyConnection { .await?; let prepared = prepared.unwrap_or(true); - let result = match prepared { - true => self - .query(statement.statement_query()?, &statement.params()) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - false => self - .query_typed(statement.raw_query(), &statement.params_typed()) - .await - .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))?, + true => { + self.query(statement.statement_query()?, &statement.params()) + .await + } + false => { + self.query_typed(statement.raw_query(), &statement.params_typed()) + .await + } }; - Ok(PSQLDriverPyQueryResult::new(result)) + let return_result = result.map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot execute query, error - {err}" + )) + })?; + + Ok(PSQLDriverPyQueryResult::new(return_result)) } pub async fn execute_many( diff --git a/src/connection/structs.rs b/src/connection/structs.rs index 9e713bfd..ccfd101d 100644 --- a/src/connection/structs.rs +++ b/src/connection/structs.rs @@ -1,12 +1,41 @@ +use std::sync::Arc; + use deadpool_postgres::Object; -use tokio_postgres::Client; +use tokio_postgres::{Client, Config}; pub struct PoolConnection { pub connection: Object, + pub in_transaction: bool, + pub in_cursor: bool, + pub pg_config: Arc, } +impl PoolConnection { + pub fn new(connection: Object, pg_config: Arc) -> Self { + Self { + connection, + in_transaction: false, + in_cursor: false, + pg_config, + } + } +} pub struct SingleConnection { pub connection: Client, + pub in_transaction: bool, + pub in_cursor: bool, + pub pg_config: Arc, +} + +impl SingleConnection { + pub fn new(connection: Client, pg_config: Arc) -> Self { + Self { + connection, + in_transaction: false, + in_cursor: false, + pg_config, + } + } } pub enum PSQLPyConnection { diff --git a/src/connection/traits.rs b/src/connection/traits.rs index 9428eb70..8e868a06 100644 --- a/src/connection/traits.rs +++ b/src/connection/traits.rs @@ -1,4 +1,5 @@ use postgres_types::{ToSql, Type}; +use pyo3::PyAny; use tokio_postgres::{Row, Statement, ToStatement}; use crate::exceptions::rust_errors::PSQLPyResult; @@ -74,14 +75,65 @@ pub trait Transaction { querystring } - fn start( - &self, + fn _start_transaction( + &mut self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> impl std::future::Future>; + + fn _commit(&self) -> impl std::future::Future>; + + fn _rollback(&self) -> impl std::future::Future>; +} + +pub trait StartTransaction: Transaction { + fn start_transaction( + &mut self, isolation_level: Option, read_variant: Option, deferrable: Option, ) -> impl std::future::Future>; +} - fn commit(&self) -> impl std::future::Future>; +pub trait CloseTransaction: StartTransaction { + fn commit(&mut self) -> impl std::future::Future>; - fn rollback(&self) -> impl std::future::Future>; + fn rollback(&mut self) -> impl std::future::Future>; +} + +pub trait Cursor { + fn build_cursor_start_qs( + &self, + cursor_name: &str, + scroll: &Option, + querystring: &str, + ) -> String { + let mut cursor_init_query = format!("DECLARE {cursor_name}"); + if let Some(scroll) = scroll { + if *scroll { + cursor_init_query.push_str(" SCROLL"); + } else { + cursor_init_query.push_str(" NO SCROLL"); + } + } + + cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); + + cursor_init_query + } + + fn start_cursor( + &mut self, + cursor_name: &str, + scroll: &Option, + querystring: String, + prepared: &Option, + parameters: Option>, + ) -> impl std::future::Future>; + + fn close_cursor( + &mut self, + cursor_name: &str, + ) -> impl std::future::Future>; } diff --git a/src/driver/common.rs b/src/driver/common.rs new file mode 100644 index 00000000..528fed84 --- /dev/null +++ b/src/driver/common.rs @@ -0,0 +1,94 @@ +use pyo3::prelude::*; +use tokio_postgres::config::Host; + +use std::net::IpAddr; + +use super::{connection::Connection, cursor::Cursor, transaction::Transaction}; + +macro_rules! impl_config_py_methods { + ($name:ident) => { + #[pymethods] + impl $name { + #[getter] + fn conn_dbname(&self) -> Option<&str> { + self.pg_config.get_dbname() + } + + #[getter] + fn user(&self) -> Option<&str> { + self.pg_config.get_user() + } + + #[getter] + fn host_addrs(&self) -> Vec { + let mut host_addrs_vec = vec![]; + + let host_addrs = self.pg_config.get_hostaddrs(); + for ip_addr in host_addrs { + match ip_addr { + IpAddr::V4(ipv4) => { + host_addrs_vec.push(ipv4.to_string()); + } + IpAddr::V6(ipv6) => { + host_addrs_vec.push(ipv6.to_string()); + } + } + } + + host_addrs_vec + } + + #[cfg(unix)] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + Host::Unix(host) => { + hosts_vec.push(host.display().to_string()); + } + } + } + + hosts_vec + } + + #[cfg(not(unix))] + #[getter] + fn hosts(&self) -> Vec { + let mut hosts_vec = vec![]; + + let hosts = self.pg_config.get_hosts(); + for host in hosts { + match host { + Host::Tcp(host) => { + hosts_vec.push(host.to_string()); + } + _ => unreachable!(), + } + } + + hosts_vec + } + + #[getter] + fn ports(&self) -> Vec<&u16> { + return self.pg_config.get_ports().iter().collect::>(); + } + + #[getter] + fn options(&self) -> Option<&str> { + return self.pg_config.get_options(); + } + } + }; +} + +impl_config_py_methods!(Transaction); +impl_config_py_methods!(Connection); +impl_config_py_methods!(Cursor); diff --git a/src/driver/common_options.rs b/src/driver/common_options.rs deleted file mode 100644 index a76d37dd..00000000 --- a/src/driver/common_options.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::time::Duration; - -use deadpool_postgres::RecyclingMethod; -use pyo3::{pyclass, pymethods}; - -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] -pub enum ConnRecyclingMethod { - Fast, - Verified, - Clean, -} - -impl ConnRecyclingMethod { - #[must_use] - pub fn to_internal(&self) -> RecyclingMethod { - match self { - ConnRecyclingMethod::Fast => RecyclingMethod::Fast, - ConnRecyclingMethod::Verified => RecyclingMethod::Verified, - ConnRecyclingMethod::Clean => RecyclingMethod::Clean, - } - } -} - -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] -pub enum LoadBalanceHosts { - /// Make connection attempts to hosts in the order provided. - Disable, - /// Make connection attempts to hosts in a random order. - Random, -} - -impl LoadBalanceHosts { - #[must_use] - pub fn to_internal(&self) -> tokio_postgres::config::LoadBalanceHosts { - match self { - LoadBalanceHosts::Disable => tokio_postgres::config::LoadBalanceHosts::Disable, - LoadBalanceHosts::Random => tokio_postgres::config::LoadBalanceHosts::Random, - } - } -} - -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] -pub enum TargetSessionAttrs { - /// No special properties are required. - Any, - /// The session must allow writes. - ReadWrite, - /// The session allow only reads. - ReadOnly, -} - -impl TargetSessionAttrs { - #[must_use] - pub fn to_internal(&self) -> tokio_postgres::config::TargetSessionAttrs { - match self { - TargetSessionAttrs::Any => tokio_postgres::config::TargetSessionAttrs::Any, - TargetSessionAttrs::ReadWrite => tokio_postgres::config::TargetSessionAttrs::ReadWrite, - TargetSessionAttrs::ReadOnly => tokio_postgres::config::TargetSessionAttrs::ReadOnly, - } - } -} - -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq, Debug)] -pub enum SslMode { - /// Do not use TLS. - Disable, - /// Pay the overhead of encryption if the server insists on it. - Allow, - /// Attempt to connect with TLS but allow sessions without. - Prefer, - /// Require the use of TLS. - Require, - /// I want my data encrypted, - /// and I accept the overhead. - /// I want to be sure that I connect to a server that I trust. - VerifyCa, - /// I want my data encrypted, - /// and I accept the overhead. - /// I want to be sure that I connect to a server I trust, - /// and that it's the one I specify. - VerifyFull, -} - -impl SslMode { - #[must_use] - pub fn to_internal(&self) -> tokio_postgres::config::SslMode { - match self { - SslMode::Disable => tokio_postgres::config::SslMode::Disable, - SslMode::Allow => tokio_postgres::config::SslMode::Allow, - SslMode::Prefer => tokio_postgres::config::SslMode::Prefer, - SslMode::Require => tokio_postgres::config::SslMode::Require, - SslMode::VerifyCa => tokio_postgres::config::SslMode::VerifyCa, - SslMode::VerifyFull => tokio_postgres::config::SslMode::VerifyFull, - } - } -} - -#[pyclass] -#[derive(Clone, Copy)] -pub struct KeepaliveConfig { - pub idle: Duration, - pub interval: Option, - pub retries: Option, -} - -#[pymethods] -impl KeepaliveConfig { - #[new] - #[pyo3(signature = (idle, interval=None, retries=None))] - fn build_config(idle: u64, interval: Option, retries: Option) -> Self { - let interval_internal = interval.map(Duration::from_secs); - KeepaliveConfig { - idle: Duration::from_secs(idle), - interval: interval_internal, - retries, - } - } -} - -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] -pub enum CopyCommandFormat { - TEXT, - CSV, - BINARY, -} - -impl CopyCommandFormat { - #[must_use] - pub fn to_internal(&self) -> String { - match self { - CopyCommandFormat::TEXT => "text".into(), - CopyCommandFormat::CSV => "csv".into(), - CopyCommandFormat::BINARY => "binary".into(), - } - } -} diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 3b5a4ab5..9635a836 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -2,8 +2,9 @@ use bytes::BytesMut; use deadpool_postgres::Pool; use futures_util::pin_mut; use pyo3::{buffer::PyBuffer, pyclass, pyfunction, pymethods, Py, PyAny, PyErr, Python}; -use std::{collections::HashSet, net::IpAddr, sync::Arc}; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, Config}; use crate::{ connection::{ @@ -12,17 +13,12 @@ use crate::{ }, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, - options::{IsolationLevel, ReadVariant}, + options::{IsolationLevel, LoadBalanceHosts, ReadVariant, SslMode, TargetSessionAttrs}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, runtime::tokio_runtime, }; -use super::{ - common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs}, - connection_pool::connect_pool, - cursor::Cursor, - transaction::Transaction, -}; +use super::{connection_pool::connect_pool, cursor::Cursor, transaction::Transaction}; /// Make new connection pool. /// @@ -121,15 +117,15 @@ pub async fn connect( #[pyclass(subclass)] #[derive(Clone)] pub struct Connection { - db_client: Option>, + db_client: Option>>, db_pool: Option, - pg_config: Arc, + pub pg_config: Arc, } impl Connection { #[must_use] pub fn new( - db_client: Option>, + db_client: Option>>, db_pool: Option, pg_config: Arc, ) -> Self { @@ -141,7 +137,7 @@ impl Connection { } #[must_use] - pub fn db_client(&self) -> Option> { + pub fn db_client(&self) -> Option>> { self.db_client.clone() } @@ -159,87 +155,14 @@ impl Default for Connection { #[pymethods] impl Connection { - #[getter] - fn conn_dbname(&self) -> Option<&str> { - self.pg_config.get_dbname() - } - - #[getter] - fn user(&self) -> Option<&str> { - self.pg_config.get_user() - } - - #[getter] - fn host_addrs(&self) -> Vec { - let mut host_addrs_vec = vec![]; - - let host_addrs = self.pg_config.get_hostaddrs(); - for ip_addr in host_addrs { - match ip_addr { - IpAddr::V4(ipv4) => { - host_addrs_vec.push(ipv4.to_string()); - } - IpAddr::V6(ipv6) => { - host_addrs_vec.push(ipv6.to_string()); - } - } - } - - host_addrs_vec - } - - #[cfg(unix)] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - Host::Unix(host) => { - hosts_vec.push(host.display().to_string()); - } - } - } - - hosts_vec - } - - #[cfg(not(unix))] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec - } - - #[getter] - fn ports(&self) -> Vec<&u16> { - return self.pg_config.get_ports().iter().collect::>(); - } - - #[getter] - fn options(&self) -> Option<&str> { - return self.pg_config.get_options(); - } - async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { - let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { + let (db_client, db_pool, pg_config) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); - (self_.db_client.clone(), self_.db_pool.clone()) + ( + self_.db_client.clone(), + self_.db_pool.clone(), + self_.pg_config.clone(), + ) }); if db_client.is_some() { @@ -247,16 +170,16 @@ impl Connection { } if let Some(db_pool) = db_pool { - let db_connection = tokio_runtime() + let connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(PSQLPyConnection::PoolConn(PoolConnection { - connection: db_connection, - }))); + self_.db_client = Some(Arc::new(RwLock::new(PSQLPyConnection::PoolConn( + PoolConnection::new(connection, pg_config), + )))); }); return Ok(self_); } @@ -310,7 +233,8 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - let res = db_client.execute(querystring, parameters, prepared).await; + let read_conn_g = db_client.read().await; + let res = read_conn_g.execute(querystring, parameters, prepared).await; return res; } @@ -333,7 +257,8 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.batch_execute(&querystring).await; + let read_conn_g = db_client.read().await; + return read_conn_g.batch_execute(&querystring).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -359,7 +284,8 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client + let read_conn_g = db_client.read().await; + return read_conn_g .execute_many(querystring, parameters, prepared) .await; } @@ -385,7 +311,8 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.execute(querystring, parameters, prepared).await; + let read_conn_g = db_client.read().await; + return read_conn_g.execute(querystring, parameters, prepared).await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -415,7 +342,10 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.fetch_row(querystring, parameters, prepared).await; + let read_conn_g = db_client.read().await; + return read_conn_g + .fetch_row(querystring, parameters, prepared) + .await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -442,7 +372,10 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.fetch_val(querystring, parameters, prepared).await; + let read_conn_g = db_client.read().await; + return read_conn_g + .fetch_val(querystring, parameters, prepared) + .await; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -465,14 +398,11 @@ impl Connection { ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Transaction::new( - db_client.clone(), + Some(db_client.clone()), self.pg_config.clone(), - false, - false, isolation_level, read_variant, deferrable, - HashSet::new(), )); } @@ -504,7 +434,6 @@ impl Connection { self.pg_config.clone(), querystring, parameters, - "cur_name".into(), fetch_number.unwrap_or(10), scroll, prepared, @@ -576,7 +505,8 @@ impl Connection { )) })?; - let sink = db_client.copy_in(©_qs).await?; + let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(©_qs).await?; let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); pin_mut!(writer); writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index c66be3da..16e1fe90 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -6,12 +6,15 @@ use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use postgres_types::Type; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::sync::Arc; +use tokio::sync::RwLock; use tokio_postgres::Config; -use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, +}; use super::{ - common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::Connection, listener::core::Listener, utils::{build_connection_config, build_manager, build_tls}, @@ -245,9 +248,9 @@ impl ConnectionPool { let connection = self.pool.get().await?; Ok(Connection::new( - Some(Arc::new(PSQLPyConnection::PoolConn(PoolConnection { - connection, - }))), + Some(Arc::new(RwLock::new(PSQLPyConnection::PoolConn( + PoolConnection::new(connection, self.pg_config.clone()), + )))), None, self.pg_config.clone(), )) @@ -422,9 +425,9 @@ impl ConnectionPool { .await??; Ok(Connection::new( - Some(Arc::new(PSQLPyConnection::PoolConn(PoolConnection { - connection, - }))), + Some(Arc::new(RwLock::new(PSQLPyConnection::PoolConn( + PoolConnection::new(connection, pg_config.clone()), + )))), None, pg_config, )) diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index 0cd7432b..dcecd761 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -3,10 +3,12 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use pyo3::{pyclass, pymethods, Py, Python}; -use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, +}; use super::{ - common_options, connection_pool::ConnectionPool, utils::{build_manager, build_tls}, }; @@ -17,7 +19,7 @@ pub struct ConnectionPoolBuilder { max_db_pool_size: Option, conn_recycling_method: Option, ca_file: Option, - ssl_mode: Option, + ssl_mode: Option, prepare: Option, } @@ -104,7 +106,7 @@ impl ConnectionPoolBuilder { /// Set connection recycling method. fn conn_recycling_method( self_: Py, - conn_recycling_method: super::common_options::ConnRecyclingMethod, + conn_recycling_method: ConnRecyclingMethod, ) -> Py { Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); @@ -171,7 +173,7 @@ impl ConnectionPoolBuilder { /// /// Defaults to `prefer`. #[must_use] - pub fn ssl_mode(self_: Py, ssl_mode: crate::driver::common_options::SslMode) -> Py { + pub fn ssl_mode(self_: Py, ssl_mode: SslMode) -> Py { Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); self_.ssl_mode = Some(ssl_mode); @@ -259,7 +261,7 @@ impl ConnectionPoolBuilder { #[must_use] pub fn target_session_attrs( self_: Py, - target_session_attrs: super::common_options::TargetSessionAttrs, + target_session_attrs: TargetSessionAttrs, ) -> Py { Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); @@ -274,10 +276,7 @@ impl ConnectionPoolBuilder { /// /// Defaults to `disable`. #[must_use] - pub fn load_balance_hosts( - self_: Py, - load_balance_hosts: super::common_options::LoadBalanceHosts, - ) -> Py { + pub fn load_balance_hosts(self_: Py, load_balance_hosts: LoadBalanceHosts) -> Py { Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); self_ diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 7229c6ee..a12d2bfa 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -1,209 +1,84 @@ -use std::{net::IpAddr, sync::Arc}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; use pyo3::{ exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, }; -use tokio_postgres::{config::Host, Config}; +use tokio::sync::RwLock; +use tokio_postgres::Config; use crate::{ - connection::structs::PSQLPyConnection, + connection::{structs::PSQLPyConnection, traits::Cursor as _}, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::PSQLDriverPyQueryResult, runtime::rustdriver_future, }; -/// Additional implementation for the `Object` type. -#[allow(clippy::ref_option)] -trait CursorObjectTrait { - async fn cursor_start( - &self, - cursor_name: &str, - scroll: &Option, - querystring: &str, - prepared: &Option, - parameters: &Option>, - ) -> PSQLPyResult<()>; - - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()>; -} - -impl CursorObjectTrait for PSQLPyConnection { - /// Start the cursor. - /// - /// Execute `DECLARE` command with parameters. - /// - /// # Errors - /// May return Err Result if cannot execute querystring. - #[allow(clippy::ref_option)] - async fn cursor_start( - &self, - cursor_name: &str, - scroll: &Option, - querystring: &str, - prepared: &Option, - parameters: &Option>, - ) -> PSQLPyResult<()> { - let mut cursor_init_query = format!("DECLARE {cursor_name}"); - if let Some(scroll) = scroll { - if *scroll { - cursor_init_query.push_str(" SCROLL"); - } else { - cursor_init_query.push_str(" NO SCROLL"); - } - } - - cursor_init_query.push_str(format!(" CURSOR FOR {querystring}").as_str()); - - self.execute(cursor_init_query, parameters.clone(), *prepared) - .await - .map_err(|err| { - RustPSQLDriverError::CursorStartError(format!("Cannot start cursor, error - {err}")) - })?; - - Ok(()) - } - - /// Close the cursor. - /// - /// Execute `CLOSE` command. - /// - /// # Errors - /// May return Err Result if cannot execute querystring. - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()> { - if *closed { - return Err(RustPSQLDriverError::CursorCloseError( - "Cursor is already closed".into(), - )); - } - - self.execute( - format!("CLOSE {cursor_name}"), - Option::default(), - Some(false), - ) - .await?; +static NEXT_CUR_ID: AtomicUsize = AtomicUsize::new(0); - Ok(()) - } +fn next_cursor_name() -> String { + format!("cur{}", NEXT_CUR_ID.fetch_add(1, Ordering::SeqCst),) } #[pyclass(subclass)] pub struct Cursor { - db_transaction: Option>, - pg_config: Arc, + conn: Option>>, + pub pg_config: Arc, querystring: String, parameters: Option>, - cursor_name: String, + cursor_name: Option, fetch_number: usize, scroll: Option, prepared: Option, - is_started: bool, - closed: bool, } impl Cursor { - #[must_use] pub fn new( - db_transaction: Arc, + conn: Arc>, pg_config: Arc, querystring: String, parameters: Option>, - cursor_name: String, fetch_number: usize, scroll: Option, prepared: Option, ) -> Self { Cursor { - db_transaction: Some(db_transaction), + conn: Some(conn), pg_config, querystring, parameters, - cursor_name, + cursor_name: None, fetch_number, scroll, prepared, - is_started: false, - closed: false, } } -} -#[pymethods] -impl Cursor { - #[getter] - fn conn_dbname(&self) -> Option<&str> { - self.pg_config.get_dbname() - } - - #[getter] - fn user(&self) -> Option<&str> { - self.pg_config.get_user() - } + async fn execute(&self, querystring: &str) -> PSQLPyResult { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + let read_conn_g = conn.read().await; - #[getter] - fn host_addrs(&self) -> Vec { - let mut host_addrs_vec = vec![]; - - let host_addrs = self.pg_config.get_hostaddrs(); - for ip_addr in host_addrs { - match ip_addr { - IpAddr::V4(ipv4) => { - host_addrs_vec.push(ipv4.to_string()); - } - IpAddr::V6(ipv6) => { - host_addrs_vec.push(ipv6.to_string()); - } - } - } - - host_addrs_vec - } - - #[cfg(unix)] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - Host::Unix(host) => { - hosts_vec.push(host.display().to_string()); - } - } - } - - hosts_vec - } - - #[cfg(not(unix))] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec - } + let result = read_conn_g + .execute(querystring.to_string(), None, Some(false)) + .await + .map_err(|err| { + RustPSQLDriverError::CursorFetchError(format!( + "Cannot fetch data from cursor, error - {err}" + )) + })?; - #[getter] - fn ports(&self) -> Vec<&u16> { - return self.pg_config.get_ports().iter().collect::>(); + Ok(result) } +} +#[pymethods] +impl Cursor { #[getter] - fn cursor_name(&self) -> String { + fn cursor_name(&self) -> Option { return self.cursor_name.clone(); } @@ -232,471 +107,237 @@ impl Cursor { } async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { - let (db_transaction, cursor_name, scroll, querystring, prepared, parameters) = - Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - ( - self_.db_transaction.clone(), - self_.cursor_name.clone(), - self_.scroll, - self_.querystring.clone(), - self_.prepared, - self_.parameters.clone(), - ) - }); - - if let Some(db_transaction) = db_transaction { - db_transaction - .cursor_start(&cursor_name, &scroll, &querystring, &prepared, ¶meters) - .await?; - Python::with_gil(|gil| { - let mut self_ = slf.borrow_mut(gil); - self_.is_started = true; - }); - return Ok(slf); - } - Err(RustPSQLDriverError::CursorClosedError) + let cursor_name = next_cursor_name(); + + let (conn, scroll, querystring, prepared, parameters) = Python::with_gil(|gil| { + let mut self_ = slf.borrow_mut(gil); + self_.cursor_name = Some(cursor_name.clone()); + ( + self_.conn.clone(), + self_.scroll, + self_.querystring.clone(), + self_.prepared, + self_.parameters.clone(), + ) + }); + + let Some(conn) = conn else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + let mut write_conn_g = conn.write().await; + + write_conn_g + .start_cursor( + &cursor_name, + &scroll, + querystring.clone(), + &prepared, + parameters.clone(), + ) + .await?; + + Ok(slf) } #[allow(clippy::needless_pass_by_value)] async fn __aexit__<'a>( - slf: Py, + &mut self, _exception_type: Py, exception: Py, _traceback: Py, ) -> PSQLPyResult<()> { - let (db_transaction, closed, cursor_name, is_exception_none, py_err) = - pyo3::Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - ( - self_.db_transaction.clone(), - self_.closed, - self_.cursor_name.clone(), - exception.is_none(gil), - PyErr::from_value(exception.into_bound(gil)), - ) - }); - - if let Some(db_transaction) = db_transaction { - db_transaction - .cursor_close(&closed, &cursor_name) - .await - .map_err(|err| { - RustPSQLDriverError::CursorCloseError(format!( - "Cannot close the cursor, error - {err}" - )) - })?; - pyo3::Python::with_gil(|gil| { - let mut self_ = slf.borrow_mut(gil); - std::mem::take(&mut self_.db_transaction); - }); - if !is_exception_none { - return Err(RustPSQLDriverError::RustPyError(py_err)); - } - return Ok(()); + self.close().await?; + + let (is_exc_none, py_err) = pyo3::Python::with_gil(|gil| { + ( + exception.is_none(gil), + PyErr::from_value(exception.into_bound(gil)), + ) + }); + + if !is_exc_none { + return Err(RustPSQLDriverError::RustPyError(py_err)); } - Err(RustPSQLDriverError::CursorClosedError) + Ok(()) } - /// Return next result from the SQL statement. - /// - /// Execute FETCH FROM - /// - /// # Errors - /// May return Err Result if can't execute querystring. fn __anext__(&self) -> PSQLPyResult> { - let db_transaction = self.db_transaction.clone(); + let conn = self.conn.clone(); let fetch_number = self.fetch_number; - let cursor_name = self.cursor_name.clone(); + let Some(cursor_name) = self.cursor_name.clone() else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + let py_future = Python::with_gil(move |gil| { rustdriver_future(gil, async move { - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH {fetch_number} FROM {cursor_name}"), - None, - Some(false), - ) - .await?; - - if result.is_empty() { - return Err(PyStopAsyncIteration::new_err( - "Iteration is over, no more results in cursor", - ) - .into()); - }; - - return Ok(result); - } - Err(RustPSQLDriverError::CursorClosedError) + let Some(conn) = conn else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + + let read_conn_g = conn.read().await; + let result = read_conn_g + .execute( + format!("FETCH {fetch_number} FROM {cursor_name}"), + None, + Some(false), + ) + .await?; + + if result.is_empty() { + return Err(PyStopAsyncIteration::new_err( + "Iteration is over, no more results in cursor", + ) + .into()); + }; + Ok(result) }) }); Ok(Some(py_future?)) } - /// Start the cursor - /// - /// # Errors - /// May return Err Result - /// if cannot execute querystring for cursor declaration. pub async fn start(&mut self) -> PSQLPyResult<()> { - let db_transaction_arc = self.db_transaction.clone(); - - if let Some(db_transaction) = db_transaction_arc { - db_transaction - .cursor_start( - &self.cursor_name, - &self.scroll, - &self.querystring, - &self.prepared, - &self.parameters, - ) - .await?; - - self.is_started = true; + if self.cursor_name.is_some() { return Ok(()); } - Err(RustPSQLDriverError::CursorClosedError) - } + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + let mut write_conn_g = conn.write().await; - /// Close the cursor. - /// - /// It executes CLOSE command to close cursor in the transaction. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn close(&mut self) -> PSQLPyResult<()> { - let db_transaction_arc = self.db_transaction.clone(); + let cursor_name = next_cursor_name(); - if let Some(db_transaction) = db_transaction_arc { - db_transaction - .cursor_close(&self.closed, &self.cursor_name) - .await?; + write_conn_g + .start_cursor( + &cursor_name, + &self.scroll, + self.querystring.clone(), + &self.prepared, + self.parameters.clone(), + ) + .await?; - self.closed = true; - std::mem::take(&mut self.db_transaction); - return Ok(()); - } + self.cursor_name = Some(cursor_name); - Err(RustPSQLDriverError::CursorClosedError) + Ok(()) } - /// Fetch data from cursor. - /// - /// It's possible to specify fetch number. - /// - /// # Errors - /// May return Err Result if cannot execute query. - #[pyo3(signature = (fetch_number=None))] - pub async fn fetch<'a>( - slf: Py, - fetch_number: Option, - ) -> PSQLPyResult { - let (db_transaction, inner_fetch_number, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - ( - self_.db_transaction.clone(), - self_.fetch_number, - self_.cursor_name.clone(), - ) - }); - - if let Some(db_transaction) = db_transaction { - let fetch_number = match fetch_number { - Some(usize) => usize, - None => inner_fetch_number, + pub async fn close(&mut self) -> PSQLPyResult<()> { + if let Some(cursor_name) = &self.cursor_name { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::CursorClosedError); }; + let mut write_conn_g = conn.write().await; + write_conn_g.close_cursor(&cursor_name).await?; + self.cursor_name = None; + }; - let result = db_transaction - .execute( - format!("FETCH {fetch_number} FROM {cursor_name}"), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - - return Ok(result); - } + self.conn = None; - Err(RustPSQLDriverError::CursorClosedError) + Ok(()) } - /// Fetch row from cursor. - /// - /// Execute FETCH NEXT. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_next<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute(format!("FETCH NEXT FROM {cursor_name}"), None, Some(false)) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + #[pyo3(signature = (fetch_number=None))] + pub async fn fetch( + &self, + fetch_number: Option, + ) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!( + "FETCH {} FROM {}", + fetch_number.unwrap_or(self.fetch_number), + cursor_name, + )) + .await } - /// Fetch previous from cursor. - /// - /// Execute FETCH PRIOR. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_prior<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute(format!("FETCH PRIOR FROM {cursor_name}"), None, Some(false)) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + pub async fn fetch_next(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH NEXT FROM {cursor_name}")) + .await } - /// Fetch first row from cursor. - /// - /// Execute FETCH FIRST (same as ABSOLUTE 1) - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_first<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute(format!("FETCH FIRST FROM {cursor_name}"), None, Some(false)) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + pub async fn fetch_prior(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH PRIOR FROM {cursor_name}")) + .await } - /// Fetch last row from cursor. - /// - /// Execute FETCH LAST (same as ABSOLUTE -1) - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_last<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute(format!("FETCH LAST FROM {cursor_name}"), None, Some(false)) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } + pub async fn fetch_first(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH FIRST FROM {cursor_name}")) + .await + } - Err(RustPSQLDriverError::CursorClosedError) + pub async fn fetch_last(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH LAST FROM {cursor_name}")) + .await } - /// Fetch absolute row from cursor. - /// - /// Execute FETCH ABSOLUTE. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_absolute<'a>( - slf: Py, + pub async fn fetch_absolute( + &self, absolute_number: i64, ) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH ABSOLUTE {absolute_number} FROM {cursor_name}"), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!( + "FETCH ABSOLUTE {absolute_number} FROM {cursor_name}" + )) + .await } - /// Fetch absolute row from cursor. - /// - /// Execute FETCH ABSOLUTE. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_relative<'a>( - slf: Py, + pub async fn fetch_relative( + &self, relative_number: i64, ) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH RELATIVE {relative_number} FROM {cursor_name}"), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!( + "FETCH RELATIVE {relative_number} FROM {cursor_name}" + )) + .await } - /// Fetch forward all from cursor. - /// - /// Execute FORWARD ALL. - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_forward_all<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH FORWARD ALL FROM {cursor_name}"), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + pub async fn fetch_forward_all(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH FORWARD ALL FROM {cursor_name}")) + .await } - /// Fetch backward from cursor. - /// - /// Execute BACKWARD . - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_backward<'a>( - slf: Py, + pub async fn fetch_backward( + &self, backward_count: i64, ) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH BACKWARD {backward_count} FROM {cursor_name}",), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!( + "FETCH BACKWARD {backward_count} FROM {cursor_name}" + )) + .await } - /// Fetch backward from cursor. - /// - /// Execute BACKWARD . - /// - /// # Errors - /// May return Err Result if cannot execute query. - pub async fn fetch_backward_all<'a>(slf: Py) -> PSQLPyResult { - let (db_transaction, cursor_name) = Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - (self_.db_transaction.clone(), self_.cursor_name.clone()) - }); - - if let Some(db_transaction) = db_transaction { - let result = db_transaction - .execute( - format!("FETCH BACKWARD ALL FROM {cursor_name}"), - None, - Some(false), - ) - .await - .map_err(|err| { - RustPSQLDriverError::CursorFetchError(format!( - "Cannot fetch data from cursor, error - {err}" - )) - })?; - return Ok(result); - } - - Err(RustPSQLDriverError::CursorClosedError) + pub async fn fetch_backward_all(&self) -> PSQLPyResult { + let Some(cursor_name) = &self.cursor_name else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + self.execute(&format!("FETCH BACKWARD ALL FROM {cursor_name}")) + .await } } diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 7837478f..8ae57d22 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -17,11 +17,11 @@ use crate::{ traits::Connection as _, }, driver::{ - common_options::SslMode, connection::Connection, utils::{build_tls, is_coroutine_function, ConfiguredTLS}, }, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + options::SslMode, runtime::{rustdriver_future, tokio_runtime}, }; @@ -222,9 +222,9 @@ impl Listener { self.receiver = Some(Arc::new(RwLock::new(receiver))); self.connection = Connection::new( - Some(Arc::new(PSQLPyConnection::SingleConnection( - SingleConnection { connection: client }, - ))), + Some(Arc::new(RwLock::new(PSQLPyConnection::SingleConnection( + SingleConnection::new(client, self.pg_config.clone()), + )))), None, self.pg_config.clone(), ); @@ -355,8 +355,9 @@ async fn dispatch_callback( async fn execute_listen( is_listened: &Arc>, listen_query: &Arc>, - client: &Arc, + client: &Arc>, ) -> PSQLPyResult<()> { + let read_conn_g = client.read().await; let mut write_is_listened = is_listened.write().await; if !write_is_listened.eq(&true) { @@ -365,7 +366,7 @@ async fn execute_listen( String::from(read_listen_query.as_str()) }; - client.batch_execute(listen_q.as_str()).await?; + read_conn_g.batch_execute(listen_q.as_str()).await?; } *write_is_listened = true; diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 1cff9f57..ab1b149c 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -1,4 +1,4 @@ -pub mod common_options; +pub mod common; pub mod connection; pub mod connection_pool; pub mod connection_pool_builder; diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 50bbfb74..c779837f 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -1,17 +1,20 @@ +use std::sync::Arc; + use bytes::BytesMut; -use futures_util::{future, pin_mut}; +use futures::{future, pin_mut}; use pyo3::{ buffer::PyBuffer, - prelude::*, - pyclass, - types::{PyList, PyTuple}, + pyclass, pymethods, + types::{PyAnyMethods, PyList, PyTuple}, + Py, PyAny, PyErr, PyResult, }; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; +use tokio::sync::RwLock; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, Config}; use crate::{ connection::{ structs::PSQLPyConnection, - traits::{Connection as _, Transaction as _}, + traits::{CloseTransaction, Connection, StartTransaction as _}, }, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, @@ -20,136 +23,37 @@ use crate::{ }; use super::cursor::Cursor; -use std::{collections::HashSet, net::IpAddr, sync::Arc}; #[pyclass(subclass)] pub struct Transaction { - pub db_client: Option>, - pg_config: Arc, - is_started: bool, - is_done: bool, + pub conn: Option>>, + pub pg_config: Arc, isolation_level: Option, read_variant: Option, deferrable: Option, - - savepoints_map: HashSet, } impl Transaction { - #[allow(clippy::too_many_arguments)] - #[must_use] pub fn new( - db_client: Arc, + conn: Option>>, pg_config: Arc, - is_started: bool, - is_done: bool, isolation_level: Option, read_variant: Option, deferrable: Option, - savepoints_map: HashSet, ) -> Self { Self { - db_client: Some(db_client), + conn, pg_config, - is_started, - is_done, isolation_level, read_variant, deferrable, - savepoints_map, } } - - fn check_is_transaction_ready(&self) -> PSQLPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is not started, please call begin() on transaction".into(), - )); - } - if self.is_done { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is already committed or rolled back".into(), - )); - } - Ok(()) - } } #[pymethods] impl Transaction { - #[getter] - fn conn_dbname(&self) -> Option<&str> { - self.pg_config.get_dbname() - } - - #[getter] - fn user(&self) -> Option<&str> { - self.pg_config.get_user() - } - - #[getter] - fn host_addrs(&self) -> Vec { - let mut host_addrs_vec = vec![]; - - let host_addrs = self.pg_config.get_hostaddrs(); - for ip_addr in host_addrs { - match ip_addr { - IpAddr::V4(ipv4) => { - host_addrs_vec.push(ipv4.to_string()); - } - IpAddr::V6(ipv6) => { - host_addrs_vec.push(ipv6.to_string()); - } - } - } - - host_addrs_vec - } - - #[cfg(unix)] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - Host::Unix(host) => { - hosts_vec.push(host.display().to_string()); - } - } - } - - hosts_vec - } - - #[cfg(not(unix))] - #[getter] - fn hosts(&self) -> Vec { - let mut hosts_vec = vec![]; - - let hosts = self.pg_config.get_hosts(); - for host in hosts { - match host { - Host::Tcp(host) => { - hosts_vec.push(host.to_string()); - } - _ => unreachable!(), - } - } - - hosts_vec - } - - #[getter] - fn ports(&self) -> Vec<&u16> { - return self.pg_config.get_ports().iter().collect::>(); - } - #[must_use] pub fn __aiter__(self_: Py) -> Py { self_ @@ -160,44 +64,25 @@ impl Transaction { } async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { - let (is_started, is_done, isolation_level, read_variant, deferrable, db_client) = - pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.is_started, - self_.is_done, - self_.isolation_level, - self_.read_variant, - self_.deferrable, - self_.db_client.clone(), - ) - }); - - if is_started { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is already started".into(), - )); - } - - if is_done { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is already committed or rolled back".into(), - )); - } - - if let Some(db_client) = db_client { - db_client - .start(isolation_level, read_variant, deferrable) - .await?; + let (isolation_level, read_variant, deferrable, conn) = pyo3::Python::with_gil(|gil| { + let self_ = self_.borrow(gil); + ( + self_.isolation_level, + self_.read_variant, + self_.deferrable, + self_.conn.clone(), + ) + }); - Python::with_gil(|gil| { - let mut self_ = self_.borrow_mut(gil); - self_.is_started = true; - }); - return Ok(self_); - } + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let mut write_conn_g = conn.write().await; + write_conn_g + .start_transaction(isolation_level, read_variant, deferrable) + .await?; - Err(RustPSQLDriverError::TransactionClosedError) + return Ok(self_); } #[allow(clippy::needless_pass_by_value)] @@ -207,456 +92,251 @@ impl Transaction { exception: Py, _traceback: Py, ) -> PSQLPyResult<()> { - let (is_transaction_ready, is_exception_none, py_err, db_client) = + let (conn, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { + let self_ = self_.borrow(gil); + ( + self_.conn.clone(), + exception.is_none(gil), + PyErr::from_value(exception.into_bound(gil)), + ) + }); + + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let mut write_conn_g = conn.write().await; + if is_exception_none { + write_conn_g.commit().await?; pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.check_is_transaction_ready(), - exception.is_none(gil), - PyErr::from_value(exception.into_bound(gil)), - self_.db_client.clone(), - ) + let mut self_ = self_.borrow_mut(gil); + self_.conn = None; }); - is_transaction_ready?; - - if let Some(db_client) = db_client { - let exit_result = if is_exception_none { - db_client.commit().await?; - Ok(()) - } else { - db_client.rollback().await?; - Err(RustPSQLDriverError::RustPyError(py_err)) - }; - + Ok(()) + } else { + write_conn_g.rollback().await?; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.is_done = true; - std::mem::take(&mut self_.db_client); + self_.conn = None; }); - return exit_result; + return Err(RustPSQLDriverError::RustPyError(py_err)); } + } - Err(RustPSQLDriverError::TransactionClosedError) + pub async fn begin(&mut self) -> PSQLPyResult<()> { + let conn = &self.conn; + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let mut write_conn_g = conn.write().await; + write_conn_g + .start_transaction(self.isolation_level, self.read_variant, self.deferrable) + .await?; + + Ok(()) } - /// Commit the transaction. - /// - /// Execute `COMMIT` command and mark transaction as `done`. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done - /// 3) Cannot execute `COMMIT` command pub async fn commit(&mut self) -> PSQLPyResult<()> { - self.check_is_transaction_ready()?; - if let Some(db_client) = &self.db_client { - db_client.commit().await?; - self.is_done = true; - std::mem::take(&mut self.db_client); - return Ok(()); - } + let conn = self.conn.clone(); + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let mut write_conn_g = conn.write().await; + write_conn_g.commit().await?; - Err(RustPSQLDriverError::TransactionClosedError) + self.conn = None; + + Ok(()) } - /// Execute ROLLBACK command. - /// - /// Run ROLLBACK command and mark the transaction as done. - /// - /// # Errors - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done - /// 3) Can not execute ROLLBACK command pub async fn rollback(&mut self) -> PSQLPyResult<()> { - self.check_is_transaction_ready()?; - if let Some(db_client) = &self.db_client { - db_client.rollback().await?; - self.is_done = true; - std::mem::take(&mut self.db_client); - return Ok(()); - } + let conn = self.conn.clone(); + let Some(conn) = conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let mut write_conn_g = conn.write().await; + write_conn_g.rollback().await?; - Err(RustPSQLDriverError::TransactionClosedError) + self.conn = None; + + Ok(()) } - /// Execute querystring with parameters. - /// - /// It converts incoming parameters to rust readable - /// and then execute the query with them. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Cannot convert python parameters - /// 2) Cannot execute querystring. #[pyo3(signature = (querystring, parameters=None, prepared=None))] pub async fn execute( - self_: Py, + &self, querystring: String, parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - is_transaction_ready?; - if let Some(db_client) = db_client { - return db_client.execute(querystring, parameters, prepared).await; - } + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; - Err(RustPSQLDriverError::TransactionClosedError) + let read_conn_g = conn.read().await; + read_conn_g.execute(querystring, parameters, prepared).await } - /// Executes a sequence of SQL statements using the simple query protocol. - /// - /// Statements should be separated by semicolons. - /// If an error occurs, execution of the sequence will stop at that point. - /// This is intended for use when, for example, - /// initializing a database schema. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Transaction is closed. - /// 2) Cannot execute querystring. - pub async fn execute_batch(self_: Py, querystring: String) -> PSQLPyResult<()> { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - is_transaction_ready?; - if let Some(db_client) = db_client { - return db_client.batch_execute(&querystring).await; - } - - Err(RustPSQLDriverError::TransactionClosedError) - } - - /// Fetch result from the database. - /// - /// It converts incoming parameters to rust readable - /// and then execute the query with them. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Cannot convert python parameters - /// 2) Cannot execute querystring. #[pyo3(signature = (querystring, parameters=None, prepared=None))] pub async fn fetch( - self_: Py, + &self, querystring: String, parameters: Option>, prepared: Option, ) -> PSQLPyResult { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - is_transaction_ready?; - if let Some(db_client) = db_client { - return db_client.execute(querystring, parameters, prepared).await; - } + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; - Err(RustPSQLDriverError::TransactionClosedError) + let read_conn_g = conn.read().await; + read_conn_g.execute(querystring, parameters, prepared).await } - /// Fetch exaclty single row from query. - /// - /// Method doesn't acquire lock on any structure fields. - /// It prepares and caches querystring in the inner Object object. - /// - /// Then execute the query. - /// - /// # Errors - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done already - /// 3) Can not create/retrieve prepared statement - /// 4) Can not execute statement - /// 5) Query returns more than one row - #[pyo3(signature = (querystring, parameters=None, prepared=None))] - pub async fn fetch_row( - self_: Py, - querystring: String, - parameters: Option>, - prepared: Option, - ) -> PSQLPyResult { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - is_transaction_ready?; - - if let Some(db_client) = db_client { - return db_client.fetch_row(querystring, parameters, prepared).await; - } - - Err(RustPSQLDriverError::TransactionClosedError) - } - /// Execute querystring with parameters and return first value in the first row. - /// - /// It converts incoming parameters to rust readable, - /// executes query with them and returns first row of response. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Cannot convert python parameters - /// 2) Cannot execute querystring. - /// 3) Query returns more than one row #[pyo3(signature = (querystring, parameters=None, prepared=None))] pub async fn fetch_val( - self_: Py, + &self, querystring: String, parameters: Option>, prepared: Option, ) -> PSQLPyResult> { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - is_transaction_ready?; - if let Some(db_client) = db_client { - return db_client.fetch_val(querystring, parameters, prepared).await; - } + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + + let read_conn_g = conn.read().await; + read_conn_g + .fetch_val(querystring, parameters, prepared) + .await + } - Err(RustPSQLDriverError::TransactionClosedError) + pub async fn execute_batch(&self, querystring: String) -> PSQLPyResult<()> { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + + let read_conn_g = conn.read().await; + read_conn_g.batch_execute(&querystring).await } - /// Execute querystring with parameters. - /// - /// It converts incoming parameters to rust readable - /// and then execute the query with them. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Cannot convert python parameters - /// 2) Cannot execute querystring. + #[pyo3(signature = (querystring, parameters=None, prepared=None))] pub async fn execute_many( - self_: Py, + &self, querystring: String, parameters: Option>>, prepared: Option, ) -> PSQLPyResult<()> { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) - }); - - is_transaction_ready?; - if let Some(db_client) = db_client { - return db_client - .execute_many(querystring, parameters, prepared) - .await; - } - - Err(RustPSQLDriverError::TransactionClosedError) + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + + let read_conn_g = conn.read().await; + read_conn_g + .execute_many(querystring, parameters, prepared) + .await } - /// Start the transaction. - /// - /// Execute `BEGIN` commands and mark transaction as `started`. - /// - /// # Errors - /// - /// May return Err Result if: - /// 1) Transaction is already started. - /// 2) Transaction is done. - /// 3) Cannot execute `BEGIN` command. - pub async fn begin(self_: Py) -> PSQLPyResult<()> { - let (is_started, is_done, isolation_level, read_variant, deferrable, db_client) = - pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.is_started, - self_.is_done, - self_.isolation_level, - self_.read_variant, - self_.deferrable, - self_.db_client.clone(), - ) - }); - - if let Some(db_client) = db_client { - if is_started { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is already started".into(), - )); - } - - if is_done { - return Err(RustPSQLDriverError::TransactionBeginError( - "Transaction is already committed or rolled back".into(), - )); - } - db_client - .start(isolation_level, read_variant, deferrable) - .await?; - pyo3::Python::with_gil(|gil| { - let mut self_ = self_.borrow_mut(gil); - self_.is_started = true; - }); - - return Ok(()); - } - - Err(RustPSQLDriverError::TransactionClosedError) + #[pyo3(signature = (querystring, parameters=None, prepared=None))] + pub async fn fetch_row( + &self, + querystring: String, + parameters: Option>, + prepared: Option, + ) -> PSQLPyResult { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + + let read_conn_g = conn.read().await; + read_conn_g + .fetch_row(querystring, parameters, prepared) + .await } - /// Create new SAVEPOINT. - /// - /// Execute SAVEPOINT and - /// add it to the transaction `rollback_savepoint` `HashSet` - /// - /// # Errors - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done - /// 3) Specified savepoint name is exists - /// 4) Can not execute SAVEPOINT command - pub async fn create_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { - let (is_transaction_ready, is_savepoint_name_exists, db_client) = - pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.check_is_transaction_ready(), - self_.savepoints_map.contains(&savepoint_name), - self_.db_client.clone(), - ) - }); - - if let Some(db_client) = db_client { - is_transaction_ready?; - - if is_savepoint_name_exists { - return Err(RustPSQLDriverError::TransactionSavepointError(format!( - "SAVEPOINT name {savepoint_name} is already taken by this transaction", - ))); - } - db_client - .batch_execute(format!("SAVEPOINT {savepoint_name}").as_str()) - .await?; + pub async fn create_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; - pyo3::Python::with_gil(|gil| { - self_.borrow_mut(gil).savepoints_map.insert(savepoint_name); - }); - return Ok(()); - } + let read_conn_g = conn.read().await; + read_conn_g + .batch_execute(format!("SAVEPOINT {savepoint_name}").as_str()) + .await?; - Err(RustPSQLDriverError::TransactionClosedError) + Ok(()) } - /// Execute RELEASE SAVEPOINT. - /// - /// Run RELEASE SAVEPOINT command. - /// - /// # Errors - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done - /// 3) Specified savepoint name doesn't exists - /// 4) Can not execute RELEASE SAVEPOINT command - pub async fn release_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { - let (is_transaction_ready, is_savepoint_name_exists, db_client) = - pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.check_is_transaction_ready(), - self_.savepoints_map.contains(&savepoint_name), - self_.db_client.clone(), - ) - }); - - if let Some(db_client) = db_client { - is_transaction_ready?; - if !is_savepoint_name_exists { - return Err(RustPSQLDriverError::TransactionSavepointError( - "Don't have rollback with this name".into(), - )); - } - db_client - .batch_execute(format!("RELEASE SAVEPOINT {savepoint_name}").as_str()) - .await?; + pub async fn release_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; - pyo3::Python::with_gil(|gil| { - self_.borrow_mut(gil).savepoints_map.remove(&savepoint_name); - }); - return Ok(()); - } + let read_conn_g = conn.read().await; + read_conn_g + .batch_execute(format!("RELEASE SAVEPOINT {savepoint_name}").as_str()) + .await?; - Err(RustPSQLDriverError::TransactionClosedError) + Ok(()) } - /// ROLLBACK to the specified savepoint - /// - /// Execute ROLLBACK TO SAVEPOINT . - /// - /// # Errors - /// May return Err Result if: - /// 1) Transaction is not started - /// 2) Transaction is done - /// 3) Specified savepoint name doesn't exist - /// 4) Can not execute ROLLBACK TO SAVEPOINT command - pub async fn rollback_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { - let (is_transaction_ready, is_savepoint_name_exists, db_client) = - pyo3::Python::with_gil(|gil| { - let self_ = self_.borrow(gil); - ( - self_.check_is_transaction_ready(), - self_.savepoints_map.contains(&savepoint_name), - self_.db_client.clone(), - ) - }); - - if let Some(db_client) = db_client { - is_transaction_ready?; - if !is_savepoint_name_exists { - return Err(RustPSQLDriverError::TransactionSavepointError( - "Don't have rollback with this name".into(), - )); - } - db_client - .batch_execute(format!("ROLLBACK TO SAVEPOINT {savepoint_name}").as_str()) - .await?; + pub async fn rollback_savepoint(&mut self, savepoint_name: String) -> PSQLPyResult<()> { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; - pyo3::Python::with_gil(|gil| { - self_.borrow_mut(gil).savepoints_map.remove(&savepoint_name); - }); - return Ok(()); - } + let read_conn_g = conn.read().await; + read_conn_g + .batch_execute(format!("ROLLBACK TO SAVEPOINT {savepoint_name}").as_str()) + .await?; - Err(RustPSQLDriverError::TransactionClosedError) + Ok(()) } - /// Execute querystrings with parameters and return all results. - /// - /// Create pipeline of queries. + + /// Create new cursor object. /// /// # Errors - /// - /// May return Err Result if: - /// 1) Cannot convert python parameters - /// 2) Cannot execute any of querystring. + /// May return Err Result if db_client is None + #[pyo3(signature = ( + querystring, + parameters=None, + fetch_number=None, + scroll=None, + prepared=None, + ))] + pub fn cursor( + &self, + querystring: String, + parameters: Option>, + fetch_number: Option, + scroll: Option, + prepared: Option, + ) -> PSQLPyResult { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + Ok(Cursor::new( + conn.clone(), + self.pg_config.clone(), + querystring, + parameters, + fetch_number.unwrap_or(10), + scroll, + prepared, + )) + } + #[pyo3(signature = (queries=None, prepared=None))] pub async fn pipeline<'py>( self_: Py, queries: Option>, prepared: Option, ) -> PSQLPyResult> { - let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { + let db_client = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); - (self_.check_is_transaction_ready(), self_.db_client.clone()) + self_.conn.clone() }); - is_transaction_ready?; - if let Some(db_client) = db_client { + let conn_read_g = db_client.read().await; let mut futures = vec![]; if let Some(queries) = queries { let gil_result = pyo3::Python::with_gil(|gil| -> PyResult<()> { @@ -672,7 +352,7 @@ impl Transaction { Ok(param) => Some(param.into()), Err(_) => None, }; - futures.push(db_client.execute(querystring, params, prepared)); + futures.push(conn_read_g.execute(querystring, params, prepared)); } Ok(()) }); @@ -691,41 +371,6 @@ impl Transaction { Err(RustPSQLDriverError::TransactionClosedError) } - /// Create new cursor object. - /// - /// # Errors - /// May return Err Result if db_client is None - #[pyo3(signature = ( - querystring, - parameters=None, - fetch_number=None, - scroll=None, - prepared=None, - ))] - pub fn cursor( - &self, - querystring: String, - parameters: Option>, - fetch_number: Option, - scroll: Option, - prepared: Option, - ) -> PSQLPyResult { - if let Some(db_client) = &self.db_client { - return Ok(Cursor::new( - db_client.clone(), - self.pg_config.clone(), - querystring, - parameters, - "cur_name".into(), - fetch_number.unwrap_or(10), - scroll, - prepared, - )); - } - - Err(RustPSQLDriverError::TransactionClosedError) - } - /// Perform binary copy to postgres table. /// /// # Errors @@ -740,7 +385,7 @@ impl Transaction { columns: Option>, schema_name: Option, ) -> PSQLPyResult { - let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); + let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).conn.clone()); let mut table_name = quote_ident(&table_name); if let Some(schema_name) = schema_name { table_name = format!("{}.{}", quote_ident(&schema_name), table_name); @@ -754,7 +399,7 @@ impl Transaction { let copy_qs = format!("COPY {table_name}{formated_columns} FROM STDIN (FORMAT binary)"); if let Some(db_client) = db_client { - let mut psql_bytes: BytesMut = Python::with_gil(|gil| { + let mut psql_bytes: BytesMut = pyo3::Python::with_gil(|gil| { let possible_py_buffer: Result, PyErr> = source.extract::>(gil); if let Ok(py_buffer) = possible_py_buffer { @@ -773,7 +418,8 @@ impl Transaction { )) })?; - let sink = db_client.copy_in(©_qs).await?; + let read_conn_g = db_client.read().await; + let sink = read_conn_g.copy_in(©_qs).await?; let writer = BinaryCopyInWriter::new_empty_buffer(sink, &[]); pin_mut!(writer); writer.as_mut().write_raw_bytes(&mut psql_bytes).await?; diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 15ca4123..e3c0a1f9 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -6,9 +6,10 @@ use postgres_openssl::MakeTlsConnector; use pyo3::{types::PyAnyMethods, Py, PyAny, Python}; use tokio_postgres::{Config, NoTls}; -use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; - -use super::common_options::{self, LoadBalanceHosts, SslMode, TargetSessionAttrs}; +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + options::{LoadBalanceHosts, SslMode, TargetSessionAttrs}, +}; /// Create new config. /// @@ -190,7 +191,7 @@ pub fn build_tls( builder.build(), ))); } else if let Some(ssl_mode) = ssl_mode { - if *ssl_mode == common_options::SslMode::Require { + if *ssl_mode == SslMode::Require { let mut builder = SslConnector::builder(SslMethod::tls())?; builder.set_verify(SslVerifyMode::NONE); return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( diff --git a/src/lib.rs b/src/lib.rs index 33ec678c..0eaac910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,11 +39,11 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; - pymod.add_class::()?; - pymod.add_class::()?; - pymod.add_class::()?; - pymod.add_class::()?; - pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; add_module(py, pymod, "extra_types", extra_types_module)?;