From 6a38f1672de224821b695789e91b16ea644f80c5 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Fri, 17 Jan 2025 19:44:06 +0100 Subject: [PATCH 01/17] init Signed-off-by: chandr-andr (Kiselev Aleksandr) --- Cargo.lock | 2 + Cargo.toml | 2 + python/psqlpy/__init__.py | 2 + python/psqlpy/_internal/__init__.pyi | 3 + src/driver/connection_pool.rs | 155 +++++++++++++---- src/driver/connection_pool_builder.rs | 39 ++--- src/driver/listener.rs | 233 ++++++++++++++++++++++++++ src/driver/mod.rs | 1 + src/driver/transaction_options.rs | 8 + src/driver/utils.rs | 53 +++++- src/lib.rs | 1 + 11 files changed, 439 insertions(+), 60 deletions(-) create mode 100644 src/driver/listener.rs diff --git a/Cargo.lock b/Cargo.lock index ce6f31cf..ad9abd60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1004,6 +1004,8 @@ dependencies = [ "chrono", "chrono-tz", "deadpool-postgres", + "futures", + "futures-channel", "futures-util", "geo-types", "itertools", diff --git a/Cargo.toml b/Cargo.toml index 0cdc86fe..6890e1b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,3 +59,5 @@ pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [ "postgres", ] } +futures-channel = "0.3.31" +futures = "0.3.31" diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index 0fb00d10..b0ec91ac 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -6,6 +6,7 @@ Cursor, IsolationLevel, KeepaliveConfig, + Listener, LoadBalanceHosts, QueryResult, ReadVariant, @@ -25,6 +26,7 @@ "Cursor", "IsolationLevel", "KeepaliveConfig", + "Listener", "LoadBalanceHosts", "QueryResult", "ReadVariant", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index ff960651..ab0df63b 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1748,3 +1748,6 @@ class ConnectionPoolBuilder: ### Returns: `ConnectionPoolBuilder` """ + +class Listener: + """Result.""" diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 8f5ea984..efd277ed 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,10 +1,10 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; -use postgres_openssl::MakeTlsConnector; -use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; -use std::{sync::Arc, vec}; -use tokio_postgres::NoTls; +use futures::{stream, FutureExt, StreamExt, TryStreamExt}; +use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny, Python}; +use std::{sync::Arc, time::Duration, vec}; +use tokio::time::sleep; +use tokio_postgres::{Config, NoTls}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -13,9 +13,10 @@ use crate::{ }; use super::{ - common_options::{self, ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, + common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::Connection, - utils::build_connection_config, + listener::Listener, + utils::{build_connection_config, build_manager, build_tls, ConfiguredTLS}, }; /// Make new connection pool. @@ -77,7 +78,6 @@ pub fn connect( load_balance_hosts: Option, ssl_mode: Option, ca_file: Option, - max_db_pool_size: Option, conn_recycling_method: Option, ) -> RustPSQLDriverPyResult { @@ -126,33 +126,25 @@ pub fn connect( }; } - let mgr: Manager; - if let Some(ca_file) = ca_file { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_ca_file(ca_file)?; - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(pg_config, tls_connector, mgr_config); - } else if let Some(ssl_mode) = ssl_mode { - if ssl_mode == common_options::SslMode::Require { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_verify(SslVerifyMode::NONE); - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(pg_config, tls_connector, mgr_config); - } else { - mgr = Manager::from_config(pg_config, NoTls, mgr_config); - } - } else { - mgr = Manager::from_config(pg_config, NoTls, mgr_config); - } + let mgr: Manager = build_manager( + mgr_config, + pg_config.clone(), + build_tls(&ca_file, ssl_mode)?, + ); let mut db_pool_builder = Pool::builder(mgr); if let Some(max_db_pool_size) = max_db_pool_size { db_pool_builder = db_pool_builder.max_size(max_db_pool_size); } - let db_pool = db_pool_builder.build()?; + let pool = db_pool_builder.build()?; - Ok(ConnectionPool(db_pool)) + Ok(ConnectionPool { + pool, + pg_config, + ca_file, + ssl_mode, + }) } #[pyclass] @@ -212,8 +204,31 @@ impl ConnectionPoolStatus { } } +// #[pyclass(subclass)] +// pub struct ConnectionPool(pub Pool); #[pyclass(subclass)] -pub struct ConnectionPool(pub Pool); +pub struct ConnectionPool { + pool: Pool, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, +} + +impl ConnectionPool { + pub fn build( + pool: Pool, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, + ) -> Self { + ConnectionPool { + pool, + pg_config, + ca_file, + ssl_mode, + } + } +} #[pymethods] impl ConnectionPool { @@ -333,7 +348,7 @@ impl ConnectionPool { #[must_use] pub fn status(&self) -> ConnectionPoolStatus { - let inner_status = self.0.status(); + let inner_status = self.pool.status(); ConnectionPoolStatus::new( inner_status.max_size, @@ -344,7 +359,7 @@ impl ConnectionPool { } pub fn resize(&self, new_max_size: usize) { - self.0.resize(new_max_size); + self.pool.resize(new_max_size); } /// Execute querystring with parameters. @@ -361,7 +376,7 @@ impl ConnectionPool { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_pool_manager = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) @@ -430,7 +445,7 @@ impl ConnectionPool { parameters: Option>, prepared: Option, ) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_pool_manager = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) }) @@ -484,7 +499,75 @@ impl ConnectionPool { #[must_use] pub fn acquire(&self) -> Connection { - Connection::new(None, Some(self.0.clone())) + Connection::new(None, Some(self.pool.clone())) + } + + pub async fn add_listener( + self_: pyo3::Py, + callback: Py, + ) -> RustPSQLDriverPyResult { + let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { + let b_gil = self_.borrow(gil); + ( + b_gil.pg_config.clone(), + b_gil.ca_file.clone(), + b_gil.ssl_mode, + ) + }); + + // let tls_ = build_tls(&ca_file, Some(SslMode::Disable)).unwrap(); + + // match tls_ { + // ConfiguredTLS::NoTls => { + // let a = pg_config.connect(NoTls).await.unwrap(); + // }, + // ConfiguredTLS::TlsConnector(connector) => { + // let a = pg_config.connect(connector).await.unwrap(); + // } + // } + + // let (client, mut connection) = tokio_runtime() + // .spawn(async move { pg_config.connect(NoTls).await.unwrap() }) + // .await?; + + // // Make transmitter and receiver. + // let (tx, mut rx) = futures_channel::mpsc::unbounded(); + // let stream = + // stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + // let connection = stream.forward(tx).map(|r| r.unwrap()); + // tokio_runtime().spawn(connection); + + // // Wait for notifications in separate thread. + // tokio_runtime().spawn(async move { + // client + // .batch_execute( + // "LISTEN test_notifications; + // LISTEN test_notifications2;", + // ) + // .await + // .unwrap(); + + // loop { + // let next_element = rx.next().await; + // client.batch_execute("LISTEN test_notifications3;").await.unwrap(); + // match next_element { + // Some(n) => { + // match n { + // tokio_postgres::AsyncMessage::Notification(n) => { + // Python::with_gil(|gil| { + // callback.call0(gil); + // }); + // println!("Notification {:?}", n); + // }, + // _ => {println!("in_in {:?}", n)} + // } + // }, + // _ => {println!("in {:?}", next_element)} + // } + // } + // }); + + Ok(Listener::new(pg_config, ca_file, ssl_mode)) } /// Return new single connection. @@ -492,7 +575,7 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub async fn connection(self_: pyo3::Py) -> RustPSQLDriverPyResult { - let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone()); + let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone()); let db_connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) @@ -507,7 +590,7 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub fn close(&self) { - let db_pool = self.0.clone(); + let db_pool = self.pool.clone(); db_pool.close(); } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index bb2047b4..963da555 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -1,14 +1,15 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; -use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pymethods, Py, Python}; -use tokio_postgres::NoTls; use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; -use super::{common_options, connection_pool::ConnectionPool}; +use super::{ + common_options, + connection_pool::ConnectionPool, + utils::{build_manager, build_tls}, +}; #[pyclass] pub struct ConnectionPoolBuilder { @@ -49,24 +50,11 @@ impl ConnectionPoolBuilder { }; }; - let mgr: Manager; - if let Some(ca_file) = &self.ca_file { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_ca_file(ca_file)?; - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); - } else if let Some(ssl_mode) = self.ssl_mode { - if ssl_mode == common_options::SslMode::Require { - let mut builder = SslConnector::builder(SslMethod::tls())?; - builder.set_verify(SslVerifyMode::NONE); - let tls_connector = MakeTlsConnector::new(builder.build()); - mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); - } else { - mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); - } - } else { - mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); - } + let mgr: Manager = build_manager( + mgr_config, + self.config.clone(), + build_tls(&self.ca_file, self.ssl_mode)?, + ); let mut db_pool_builder = Pool::builder(mgr); if let Some(max_db_pool_size) = self.max_db_pool_size { @@ -75,7 +63,12 @@ impl ConnectionPoolBuilder { let db_pool = db_pool_builder.build()?; - Ok(ConnectionPool(db_pool)) + Ok(ConnectionPool::build( + db_pool, + self.config.clone(), + self.ca_file.clone(), + self.ssl_mode, + )) } /// Set ca_file for ssl_mode in PostgreSQL. diff --git a/src/driver/listener.rs b/src/driver/listener.rs new file mode 100644 index 00000000..ec31abe9 --- /dev/null +++ b/src/driver/listener.rs @@ -0,0 +1,233 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::Arc, + task::Poll, +}; + +use futures::{stream, FutureExt, StreamExt, TryStreamExt}; +use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; +use postgres_openssl::{MakeTlsConnector, TlsStream}; +use pyo3::{pyclass, pymethods, Py, PyAny, PyObject, Python}; +use tokio::{sync::RwLock, task::AbortHandle}; +use tokio_postgres::{AsyncMessage, Client, Config, Connection, NoTls, Socket}; + +use crate::{ + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + runtime::{rustdriver_future, tokio_runtime}, +}; + +use super::{ + common_options::SslMode, + transaction_options::{ + IsolationLevel, ListenerTransactionConfig, ReadVariant, SynchronousCommit, + }, + utils::{build_tls, ConfiguredTLS}, +}; + +#[pyclass] +pub struct Listener { + // name: String, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, + // transaction_config: Option, + channel_callbacks: HashMap>>, + listen_abort_handler: Option, + client: Option>, + receiver: Option>>>, + is_listened: Arc>, +} + +impl Listener { + pub fn new( + // name: String, + pg_config: Config, + ca_file: Option, + ssl_mode: Option, + // transaction_config: Option, + ) -> Self { + Listener { + // name: name, + pg_config: pg_config, + ca_file: ca_file, + // transaction_config: transaction_config, + ssl_mode: ssl_mode, + channel_callbacks: Default::default(), + listen_abort_handler: Default::default(), + client: Default::default(), + receiver: Default::default(), + is_listened: Arc::new(RwLock::new(false)), + } + } +} + +#[pymethods] +impl Listener { + #[must_use] + fn __aiter__(slf: Py) -> Py { + slf + } + + fn __await__(slf: Py) -> Py { + slf + } + + async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + Ok(slf) + } + + async fn __aexit__<'a>( + slf: Py, + _exception_type: Py, + exception: Py, + _traceback: Py, + ) -> RustPSQLDriverPyResult<()> { + Ok(()) + } + + fn __anext__(&self) -> RustPSQLDriverPyResult> { + let Some(client) = self.client.clone() else { + return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + }; + let is_listened = self.is_listened.clone(); + let py_future = Python::with_gil(move |gil| { + rustdriver_future(gil, async move { + let mut write_is_listened = is_listened.write().await; + if write_is_listened.eq(&false) { + println!("here1"); + client + .batch_execute( + "LISTEN test_notifications; + LISTEN test_notifications2;", + ) + .await + .unwrap(); + + *write_is_listened = true; + } + let mut write_receiver = receiver.write().await; + let next_element = write_receiver.next().await; + println!("here2"); + + match next_element { + Some(n) => match n { + tokio_postgres::AsyncMessage::Notification(n) => { + println!("Notification {:?}", n); + return Ok(()); + } + _ => { + println!("in_in {:?}", n) + } + }, + _ => { + println!("in {:?}", next_element) + } + } + + Ok(()) + }) + }); + + Ok(Some(py_future?)) + } + + async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { + let tls_ = build_tls(&self.ca_file.clone(), self.ssl_mode)?; + + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + + let pg_config = self.pg_config.clone(); + let connect_future = async move { + match tls_ { + ConfiguredTLS::NoTls => { + return pg_config + .connect(MakeTlsConnector::new(builder.build())) + .await; + } + ConfiguredTLS::TlsConnector(connector) => { + return pg_config.connect(connector).await; + } + } + }; + + let (client, mut connection) = tokio_runtime().spawn(connect_future).await??; + + let (transmitter, receiver) = futures_channel::mpsc::unbounded::(); + + let stream = + stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + + let connection = stream.forward(transmitter).map(|r| r.unwrap()); + tokio_runtime().spawn(connection); + + self.receiver = Some(Arc::new(RwLock::new(receiver))); + self.client = Some(Arc::new(client)); + + Ok(()) + } + + fn add_callback(&mut self, channel: String, callback: Py) -> RustPSQLDriverPyResult<()> { + match self.channel_callbacks.entry(channel) { + Entry::Vacant(e) => { + e.insert(vec![callback]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(callback); + } + }; + + Ok(()) + } + + async fn listen(&mut self) -> RustPSQLDriverPyResult<()> { + let Some(client) = self.client.clone() else { + return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + }; + + let jh = tokio_runtime().spawn(async move { + client + .batch_execute( + "LISTEN test_notifications; + LISTEN test_notifications2;", + ) + .await + .unwrap(); + + loop { + let mut write_receiver = receiver.write().await; + let next_element = write_receiver.next().await; + client + .batch_execute("LISTEN test_notifications3;") + .await + .unwrap(); + match next_element { + Some(n) => match n { + tokio_postgres::AsyncMessage::Notification(n) => { + println!("Notification {:?}", n); + } + _ => { + println!("in_in {:?}", n) + } + }, + _ => { + println!("in {:?}", next_element) + } + } + } + }); + + let abj = jh.abort_handle(); + + self.listen_abort_handler = Some(abj); + + Ok(()) + } +} diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 183a0343..578bf2cd 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,6 +3,7 @@ pub mod connection; pub mod connection_pool; pub mod connection_pool_builder; pub mod cursor; +pub mod listener; pub mod transaction; pub mod transaction_options; pub mod utils; diff --git a/src/driver/transaction_options.rs b/src/driver/transaction_options.rs index 51467761..281b9a71 100644 --- a/src/driver/transaction_options.rs +++ b/src/driver/transaction_options.rs @@ -71,3 +71,11 @@ impl SynchronousCommit { } } } + +#[derive(Clone, Copy, PartialEq)] +pub struct ListenerTransactionConfig { + isolation_level: Option, + read_variant: Option, + deferrable: Option, + synchronous_commit: Option, +} diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 33bf0aea..41759216 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -1,8 +1,13 @@ use std::{str::FromStr, time::Duration}; +use deadpool_postgres::{Manager, ManagerConfig}; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; +use postgres_openssl::MakeTlsConnector; +use tokio_postgres::{Config, NoTls}; + use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; -use super::common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs}; +use super::common_options::{self, LoadBalanceHosts, SslMode, TargetSessionAttrs}; /// Create new config. /// @@ -163,3 +168,49 @@ pub fn build_connection_config( Ok(pg_config) } + +pub enum ConfiguredTLS { + NoTls, + TlsConnector(MakeTlsConnector), +} + +pub fn build_tls( + ca_file: &Option, + ssl_mode: Option, +) -> RustPSQLDriverPyResult { + if let Some(ca_file) = ca_file { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_ca_file(ca_file)?; + return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( + builder.build(), + ))); + } else if let Some(ssl_mode) = ssl_mode { + if ssl_mode == common_options::SslMode::Require { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( + builder.build(), + ))); + } + } + + Ok(ConfiguredTLS::NoTls) +} + +pub fn build_manager( + mgr_config: ManagerConfig, + pg_config: Config, + configured_tls: ConfiguredTLS, +) -> Manager { + let mgr: Manager; + match configured_tls { + ConfiguredTLS::NoTls => { + mgr = Manager::from_config(pg_config, NoTls, mgr_config); + } + ConfiguredTLS::TlsConnector(connector) => { + mgr = Manager::from_config(pg_config, connector, mgr_config); + } + } + + return mgr; +} diff --git a/src/lib.rs b/src/lib.rs index edda3119..a535f4c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; From e0090ec3fa26d50599dc142215a1ce107fa6457c Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Fri, 17 Jan 2025 19:44:39 +0100 Subject: [PATCH 02/17] init Signed-off-by: chandr-andr (Kiselev Aleksandr) --- src/driver/connection_pool.rs | 13 ++++++------- src/driver/listener.rs | 30 +++++++++++++----------------- src/driver/utils.rs | 4 ++-- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index efd277ed..23faaa68 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,10 +1,9 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use futures::{stream, FutureExt, StreamExt, TryStreamExt}; -use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny, Python}; -use std::{sync::Arc, time::Duration, vec}; -use tokio::time::sleep; -use tokio_postgres::{Config, NoTls}; +use futures::{FutureExt, StreamExt, TryStreamExt}; +use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; +use std::{sync::Arc, vec}; +use tokio_postgres::Config; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -16,7 +15,7 @@ use super::{ common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::Connection, listener::Listener, - utils::{build_connection_config, build_manager, build_tls, ConfiguredTLS}, + utils::{build_connection_config, build_manager, build_tls}, }; /// Make new connection pool. @@ -215,7 +214,7 @@ pub struct ConnectionPool { } impl ConnectionPool { - pub fn build( + #[must_use] pub fn build( pool: Pool, pg_config: Config, ca_file: Option, diff --git a/src/driver/listener.rs b/src/driver/listener.rs index ec31abe9..55459f48 100644 --- a/src/driver/listener.rs +++ b/src/driver/listener.rs @@ -1,16 +1,15 @@ use std::{ collections::{hash_map::Entry, HashMap}, sync::Arc, - task::Poll, }; use futures::{stream, FutureExt, StreamExt, TryStreamExt}; -use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; +use futures_channel::mpsc::UnboundedReceiver; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; -use postgres_openssl::{MakeTlsConnector, TlsStream}; +use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pymethods, Py, PyAny, PyObject, Python}; use tokio::{sync::RwLock, task::AbortHandle}; -use tokio_postgres::{AsyncMessage, Client, Config, Connection, NoTls, Socket}; +use tokio_postgres::{AsyncMessage, Client, Config}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -19,9 +18,6 @@ use crate::{ use super::{ common_options::SslMode, - transaction_options::{ - IsolationLevel, ListenerTransactionConfig, ReadVariant, SynchronousCommit, - }, utils::{build_tls, ConfiguredTLS}, }; @@ -40,7 +36,7 @@ pub struct Listener { } impl Listener { - pub fn new( + #[must_use] pub fn new( // name: String, pg_config: Config, ca_file: Option, @@ -49,10 +45,10 @@ impl Listener { ) -> Self { Listener { // name: name, - pg_config: pg_config, - ca_file: ca_file, + pg_config, + ca_file, // transaction_config: transaction_config, - ssl_mode: ssl_mode, + ssl_mode, channel_callbacks: Default::default(), listen_abort_handler: Default::default(), client: Default::default(), @@ -116,15 +112,15 @@ impl Listener { match next_element { Some(n) => match n { tokio_postgres::AsyncMessage::Notification(n) => { - println!("Notification {:?}", n); + println!("Notification {n:?}"); return Ok(()); } _ => { - println!("in_in {:?}", n) + println!("in_in {n:?}"); } }, _ => { - println!("in {:?}", next_element) + println!("in {next_element:?}"); } } @@ -211,14 +207,14 @@ impl Listener { match next_element { Some(n) => match n { tokio_postgres::AsyncMessage::Notification(n) => { - println!("Notification {:?}", n); + println!("Notification {n:?}"); } _ => { - println!("in_in {:?}", n) + println!("in_in {n:?}"); } }, _ => { - println!("in {:?}", next_element) + println!("in {next_element:?}"); } } } diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 41759216..6502c1dc 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -197,7 +197,7 @@ pub fn build_tls( Ok(ConfiguredTLS::NoTls) } -pub fn build_manager( +#[must_use] pub fn build_manager( mgr_config: ManagerConfig, pg_config: Config, configured_tls: ConfiguredTLS, @@ -212,5 +212,5 @@ pub fn build_manager( } } - return mgr; + mgr } From 6b22424eaeaa980411ad198f064effdff1f87e5f Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 18 Jan 2025 23:02:39 +0100 Subject: [PATCH 03/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/__init__.py | 2 + python/psqlpy/_internal/__init__.pyi | 4 + src/common.rs | 7 +- src/driver/connection.rs | 172 +++++------ src/driver/connection_pool.rs | 58 +--- src/driver/cursor.rs | 9 +- src/driver/listener.rs | 421 ++++++++++++++++++++------- src/driver/transaction.rs | 10 +- src/driver/utils.rs | 16 + src/exceptions/python_errors.rs | 10 + src/exceptions/rust_errors.rs | 22 +- src/lib.rs | 1 + 12 files changed, 472 insertions(+), 260 deletions(-) diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index b0ec91ac..edbdb53b 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -7,6 +7,7 @@ IsolationLevel, KeepaliveConfig, Listener, + ListenerNotification, LoadBalanceHosts, QueryResult, ReadVariant, @@ -27,6 +28,7 @@ "IsolationLevel", "KeepaliveConfig", "Listener", + "ListenerNotification", "LoadBalanceHosts", "QueryResult", "ReadVariant", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index ab0df63b..c05bfc8e 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1751,3 +1751,7 @@ class ConnectionPoolBuilder: class Listener: """Result.""" + + +class ListenerNotification: + """Result.""" diff --git a/src/common.rs b/src/common.rs index 0a4382ec..f4772860 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,13 +1,10 @@ -use deadpool_postgres::Object; use pyo3::{ types::{PyAnyMethods, PyModule, PyModuleMethods}, Bound, PyAny, PyResult, Python, }; use crate::{ - exceptions::rust_errors::RustPSQLDriverPyResult, - query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{convert_parameters, PythonDTO, QueryParameter}, + driver::connection::InnerConnection, exceptions::rust_errors::RustPSQLDriverPyResult, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, value_converter::{convert_parameters, PythonDTO, QueryParameter} }; /// Add new module to the parent one. @@ -55,7 +52,7 @@ pub trait ObjectQueryTrait { ) -> impl std::future::Future> + Send; } -impl ObjectQueryTrait for Object { +impl ObjectQueryTrait for InnerConnection { async fn psqlpy_query_one( &self, querystring: String, diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 97dc66a1..b3add461 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,9 +1,10 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use deadpool_postgres::{Object, Pool}; use futures_util::pin_mut; +use postgres_types::ToSql; use pyo3::{buffer::PyBuffer, pyclass, pymethods, Py, PyAny, PyErr, Python}; use std::{collections::HashSet, sync::Arc, vec}; -use tokio_postgres::binary_copy::BinaryCopyInWriter; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, Client, CopyInSink, Row, Statement, ToStatement}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -19,110 +20,115 @@ use super::{ transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; -/// Format OPTS parameter for Postgres COPY command. -/// -/// # Errors -/// May return Err Result if cannot format parameter. -#[allow(clippy::too_many_arguments)] -pub fn _format_copy_opts( - format: Option, - freeze: Option, - delimiter: Option, - null: Option, - header: Option, - quote: Option, - escape: Option, - force_quote: Option>, - force_not_null: Option>, - force_null: Option>, - encoding: Option, -) -> RustPSQLDriverPyResult { - let mut opts: Vec = vec![]; - - if let Some(format) = format { - opts.push(format!("FORMAT {format}")); - } +pub enum InnerConnection { + PoolConn(Object), + SingleConn(Client), +} - if let Some(freeze) = freeze { - if freeze { - opts.push("FREEZE TRUE".into()); - } else { - opts.push("FREEZE FALSE".into()); +impl InnerConnection { + pub async fn prepare_cached( + &self, + query: &str + ) -> RustPSQLDriverPyResult { + match self { + InnerConnection::PoolConn(pconn) => { + return Ok(pconn.prepare_cached(query).await?) + } + InnerConnection::SingleConn(sconn) => { + return Ok(sconn.prepare(query).await?) + } } } - - if let Some(delimiter) = delimiter { - opts.push(format!("DELIMITER {delimiter}")); - } - - if let Some(null) = null { - opts.push(format!("NULL {}", quote_ident(&null))); - } - - if let Some(header) = header { - opts.push(format!("HEADER {header}")); - } - - if let Some(quote) = quote { - opts.push(format!("QUOTE {quote}")); - } - - if let Some(escape) = escape { - opts.push(format!("ESCAPE {escape}")); - } - - if let Some(force_quote) = force_quote { - let boolean_force_quote: Result = - Python::with_gil(|gil| force_quote.extract::(gil)); - - if let Ok(force_quote) = boolean_force_quote { - if force_quote { - opts.push("FORCE_QUOTE *".into()); + + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult> + where T: ?Sized + ToStatement { + match self { + InnerConnection::PoolConn(pconn) => { + return Ok(pconn.query(statement, params).await?) } - } else { - let sequence_force_quote: Result, PyErr> = - Python::with_gil(|gil| force_quote.extract::>(gil)); - - if let Ok(force_quote) = sequence_force_quote { - opts.push(format!("FORCE_QUOTE ({})", force_quote.join(", "))); + InnerConnection::SingleConn(sconn) => { + return Ok(sconn.query(statement, params).await?) } - - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "force_quote parameter must be boolean or sequence of str's.".into(), - )); } } - if let Some(force_not_null) = force_not_null { - opts.push(format!("FORCE_NOT_NULL ({})", force_not_null.join(", "))); - } - - if let Some(force_null) = force_null { - opts.push(format!("FORCE_NULL ({})", force_null.join(", "))); + pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { + match self { + InnerConnection::PoolConn(pconn) => { + return Ok(pconn.batch_execute(query).await?) + } + InnerConnection::SingleConn(sconn) => { + return Ok(sconn.batch_execute(query).await?) + } + } } - if let Some(encoding) = encoding { - opts.push(format!("ENCODING {}", quote_ident(&encoding))); + pub async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> RustPSQLDriverPyResult + where T: ?Sized + ToStatement + { + match self { + InnerConnection::PoolConn(pconn) => { + return Ok(pconn.query_one(statement, params).await?) + } + InnerConnection::SingleConn(sconn) => { + return Ok(sconn.query_one(statement, params).await?) + } + } } - if opts.is_empty() { - Ok(String::new()) - } else { - Ok(format!("({})", opts.join(", "))) + pub async fn copy_in( + &self, + statement: &T + ) -> RustPSQLDriverPyResult> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send + { + match self { + InnerConnection::PoolConn(pconn) => { + return Ok(pconn.copy_in(statement).await?) + } + InnerConnection::SingleConn(sconn) => { + return Ok(sconn.copy_in(statement).await?) + } + } } } #[pyclass(subclass)] +#[derive(Clone)] pub struct Connection { - db_client: Option>, + db_client: Option>, db_pool: Option, } impl Connection { #[must_use] - pub fn new(db_client: Option>, db_pool: Option) -> Self { + pub fn new(db_client: Option>, db_pool: Option) -> Self { Connection { db_client, db_pool } } + + pub fn db_client(&self) -> Option> { + return self.db_client.clone() + } + + pub fn db_pool(&self) -> Option { + return self.db_pool.clone() + } +} + +impl Default for Connection { + fn default() -> Self { + Connection::new(None, None) + } } #[pymethods] @@ -145,7 +151,7 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(db_connection)); + self_.db_client = Some(Arc::new(InnerConnection::PoolConn(db_connection))); }); return Ok(self_); } diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 23faaa68..6d090c64 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,6 +1,5 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use futures::{FutureExt, StreamExt, TryStreamExt}; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::{sync::Arc, vec}; use tokio_postgres::Config; @@ -13,7 +12,7 @@ use crate::{ use super::{ common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, - connection::Connection, + connection::{Connection, InnerConnection}, listener::Listener, utils::{build_connection_config, build_manager, build_tls}, }; @@ -503,7 +502,6 @@ impl ConnectionPool { pub async fn add_listener( self_: pyo3::Py, - callback: Py, ) -> RustPSQLDriverPyResult { let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { let b_gil = self_.borrow(gil); @@ -514,58 +512,6 @@ impl ConnectionPool { ) }); - // let tls_ = build_tls(&ca_file, Some(SslMode::Disable)).unwrap(); - - // match tls_ { - // ConfiguredTLS::NoTls => { - // let a = pg_config.connect(NoTls).await.unwrap(); - // }, - // ConfiguredTLS::TlsConnector(connector) => { - // let a = pg_config.connect(connector).await.unwrap(); - // } - // } - - // let (client, mut connection) = tokio_runtime() - // .spawn(async move { pg_config.connect(NoTls).await.unwrap() }) - // .await?; - - // // Make transmitter and receiver. - // let (tx, mut rx) = futures_channel::mpsc::unbounded(); - // let stream = - // stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); - // let connection = stream.forward(tx).map(|r| r.unwrap()); - // tokio_runtime().spawn(connection); - - // // Wait for notifications in separate thread. - // tokio_runtime().spawn(async move { - // client - // .batch_execute( - // "LISTEN test_notifications; - // LISTEN test_notifications2;", - // ) - // .await - // .unwrap(); - - // loop { - // let next_element = rx.next().await; - // client.batch_execute("LISTEN test_notifications3;").await.unwrap(); - // match next_element { - // Some(n) => { - // match n { - // tokio_postgres::AsyncMessage::Notification(n) => { - // Python::with_gil(|gil| { - // callback.call0(gil); - // }); - // println!("Notification {:?}", n); - // }, - // _ => {println!("in_in {:?}", n)} - // } - // }, - // _ => {println!("in {:?}", next_element)} - // } - // } - // }); - Ok(Listener::new(pg_config, ca_file, ssl_mode)) } @@ -581,7 +527,7 @@ impl ConnectionPool { }) .await??; - Ok(Connection::new(Some(Arc::new(db_connection)), None)) + Ok(Connection::new(Some(Arc::new(InnerConnection::PoolConn(db_connection))), None)) } /// Close connection pool. diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 3f8008be..3284bc19 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use deadpool_postgres::Object; use pyo3::{ exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, }; @@ -12,6 +11,8 @@ use crate::{ runtime::rustdriver_future, }; +use super::connection::InnerConnection; + /// Additional implementation for the `Object` type. #[allow(clippy::ref_option)] trait CursorObjectTrait { @@ -27,7 +28,7 @@ trait CursorObjectTrait { async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; } -impl CursorObjectTrait for Object { +impl CursorObjectTrait for InnerConnection { /// Start the cursor. /// /// Execute `DECLARE` command with parameters. @@ -89,7 +90,7 @@ impl CursorObjectTrait for Object { #[pyclass(subclass)] pub struct Cursor { - db_transaction: Option>, + db_transaction: Option>, querystring: String, parameters: Option>, cursor_name: String, @@ -103,7 +104,7 @@ pub struct Cursor { impl Cursor { #[must_use] pub fn new( - db_transaction: Arc, + db_transaction: Arc, querystring: String, parameters: Option>, cursor_name: String, diff --git a/src/driver/listener.rs b/src/driver/listener.rs index 55459f48..7f9dadb7 100644 --- a/src/driver/listener.rs +++ b/src/driver/listener.rs @@ -7,55 +7,175 @@ use futures::{stream, FutureExt, StreamExt, TryStreamExt}; use futures_channel::mpsc::UnboundedReceiver; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; -use pyo3::{pyclass, pymethods, Py, PyAny, PyObject, Python}; -use tokio::{sync::RwLock, task::AbortHandle}; -use tokio_postgres::{AsyncMessage, Client, Config}; +use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python}; +use pyo3_async_runtimes::TaskLocals; +use tokio::{sync::RwLock, task::{AbortHandle, JoinHandle}}; +use tokio_postgres::{AsyncMessage, Config, Notification}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - runtime::{rustdriver_future, tokio_runtime}, + driver::utils::is_coroutine_function, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime} }; use super::{ - common_options::SslMode, - utils::{build_tls, ConfiguredTLS}, + common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, ConfiguredTLS} }; +struct ChannelCallbacks(HashMap>); + +impl Default for ChannelCallbacks { + fn default() -> Self { + ChannelCallbacks(Default::default()) + } +} + +impl ChannelCallbacks { + fn add_callback(&mut self, channel: String, callback: ListenerCallback) { + match self.0.entry(channel) { + Entry::Vacant(e) => { + e.insert(vec![callback]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(callback); + } + }; + } + + fn retrieve_channel_callbacks(&self, channel: String) -> Option<&Vec> { + self.0.get(&channel) + } + + fn clear_channel_callbacks(&mut self, channel: String) { + self.0.remove(&channel); + } + + fn retrieve_all_channels(&self) -> Vec<&String> { + self.0.keys().collect::>() + } +} + + +#[derive(Clone, Debug)] +#[pyclass] +pub struct ListenerNotification { + process_id: i32, + channel: String, + payload: String, +} + +impl From:: for ListenerNotification { + fn from(value: Notification) -> Self { + ListenerNotification { + process_id: value.process_id(), + channel: String::from(value.channel()), + payload: String::from(value.payload()), + } + } +} + +#[pyclass] +struct ListenerNotificationMsg { + process_id: i32, + channel: String, + payload: String, + connection: Connection, +} + +struct ListenerCallback { + task_locals: Option, + callback: Py, +} + +impl ListenerCallback { + pub fn new( + task_locals: Option, + callback: Py, + ) -> Self { + ListenerCallback { + task_locals, + callback, + } + } + + async fn call( + &self, + lister_notification: ListenerNotification, + connection: Connection, + ) -> RustPSQLDriverPyResult<()> { + let (callback, task_locals) = Python::with_gil(|py| { + if let Some(task_locals) = &self.task_locals { + return (self.callback.clone(), Some(task_locals.clone_ref(py))); + } + (self.callback.clone(), None) + }); + + if let Some(task_locals) = task_locals { + tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { + let future = Python::with_gil(|py| { + let awaitable = callback.call1(py, (lister_notification, connection)).unwrap(); + pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap() + }); + future.await.unwrap(); + })).await?; + }; + + Ok(()) + } +} + #[pyclass] pub struct Listener { - // name: String, pg_config: Config, ca_file: Option, ssl_mode: Option, - // transaction_config: Option, - channel_callbacks: HashMap>>, + channel_callbacks: Arc>, listen_abort_handler: Option, - client: Option>, + connection: Connection, receiver: Option>>>, + listen_query: Arc>, is_listened: Arc>, + is_started: bool, } impl Listener { #[must_use] pub fn new( - // name: String, pg_config: Config, ca_file: Option, ssl_mode: Option, - // transaction_config: Option, ) -> Self { Listener { - // name: name, pg_config, ca_file, - // transaction_config: transaction_config, ssl_mode, channel_callbacks: Default::default(), listen_abort_handler: Default::default(), - client: Default::default(), + connection: Connection::new(None, None), receiver: Default::default(), + listen_query: Default::default(), is_listened: Arc::new(RwLock::new(false)), + is_started: false, } } + + async fn update_listen_query(&self) -> () { + let read_channel_callbacks = self.channel_callbacks.read().await; + + let channels = read_channel_callbacks.retrieve_all_channels(); + + let mut final_query: String = String::default(); + + for channel_name in channels { + final_query.push_str( + format!("LISTEN {};", channel_name).as_str() + ); + } + + let mut write_listen_query = self.listen_query.write().await; + let mut write_is_listened = self.is_listened.write().await; + + write_listen_query.clear(); + write_listen_query.push_str(&final_query); + *write_is_listened = false; + } } #[pymethods] @@ -79,59 +199,45 @@ impl Listener { exception: Py, _traceback: Py, ) -> RustPSQLDriverPyResult<()> { - Ok(()) - } + let (client, is_exception_none, py_err) = + pyo3::Python::with_gil(|gil| { + let self_ = slf.borrow(gil); + ( + self_.connection.db_client(), + exception.is_none(gil), + PyErr::from_value_bound(exception.into_bound(gil)), + ) + }); - fn __anext__(&self) -> RustPSQLDriverPyResult> { - let Some(client) = self.client.clone() else { - return Err(RustPSQLDriverError::BaseConnectionError("test".into())); - }; - let Some(receiver) = self.receiver.clone() else { - return Err(RustPSQLDriverError::BaseConnectionError("test".into())); - }; - let is_listened = self.is_listened.clone(); - let py_future = Python::with_gil(move |gil| { - rustdriver_future(gil, async move { - let mut write_is_listened = is_listened.write().await; - if write_is_listened.eq(&false) { - println!("here1"); - client - .batch_execute( - "LISTEN test_notifications; - LISTEN test_notifications2;", - ) - .await - .unwrap(); - - *write_is_listened = true; - } - let mut write_receiver = receiver.write().await; - let next_element = write_receiver.next().await; - println!("here2"); - - match next_element { - Some(n) => match n { - tokio_postgres::AsyncMessage::Notification(n) => { - println!("Notification {n:?}"); - return Ok(()); - } - _ => { - println!("in_in {n:?}"); - } - }, - _ => { - println!("in {next_element:?}"); - } - } + if client.is_some() { + pyo3::Python::with_gil(|gil| { + let mut self_ = slf.borrow_mut(gil); + std::mem::take(&mut self_.connection); + std::mem::take(&mut self_.receiver); + }); - Ok(()) - }) - }); + if !is_exception_none { + return Err(RustPSQLDriverError::RustPyError(py_err)); + } - Ok(Some(py_future?)) + return Ok(()); + } + + return Err(RustPSQLDriverError::ListenerClosedError) + } + + #[getter] + fn connection(&self) -> Connection { + self.connection.clone() } async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { + if self.is_started { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener is already started".into(), + )); + } + let tls_ = build_tls(&self.ca_file.clone(), self.ssl_mode)?; let mut builder = SslConnector::builder(SslMethod::tls())?; @@ -156,65 +262,134 @@ impl Listener { let (transmitter, receiver) = futures_channel::mpsc::unbounded::(); let stream = - stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + stream::poll_fn( + move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e), + ); let connection = stream.forward(transmitter).map(|r| r.unwrap()); tokio_runtime().spawn(connection); self.receiver = Some(Arc::new(RwLock::new(receiver))); - self.client = Some(Arc::new(client)); + self.connection = Connection::new( + Some(Arc::new(InnerConnection::SingleConn(client))), + None, + ); + + self.is_started = true; Ok(()) } - fn add_callback(&mut self, channel: String, callback: Py) -> RustPSQLDriverPyResult<()> { - match self.channel_callbacks.entry(channel) { - Entry::Vacant(e) => { - e.insert(vec![callback]); - } - Entry::Occupied(mut e) => { - e.get_mut().push(callback); - } + fn __anext__(&self) -> RustPSQLDriverPyResult> { + let Some(client) = self.connection.db_client() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying client, please call startup".into(), + )); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying receiver, please call startup".into(), + )); }; + let is_listened_clone = self.is_listened.clone(); + let listen_query_clone = self.listen_query.clone(); + + let py_future = Python::with_gil(move |gil| { + rustdriver_future(gil, async move { + { + call_listen(&is_listened_clone, &listen_query_clone, &client).await?; + }; + let next_element = { + let mut write_receiver = receiver.write().await; + write_receiver.next().await + }; + + let inner_notification = process_message(next_element)?; + + Ok(inner_notification) + }) + }); + + Ok(Some(py_future?)) + } + + #[pyo3(signature = (channel, callback))] + async fn add_callback( + &mut self, + channel: String, + callback: Py, + ) -> RustPSQLDriverPyResult<()> { + let callback_clone = callback.clone(); + + let is_coro = is_coroutine_function(callback_clone)?; + + if !is_coro { + return Err(RustPSQLDriverError::ListenerCallbackError) + } + + let task_locals = Python::with_gil(|py| { + pyo3_async_runtimes::tokio::get_current_locals(py)} + )?; + + let listener_callback = ListenerCallback::new( + Some(task_locals), + callback, + ); + + // let awaitable = callback.call1(()).unwrap(); + // println!("8888888 {:?}", awaitable); + // let bbb = pyo3_async_runtimes::tokio::into_future(awaitable).unwrap(); + // println!("999999"); + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.add_callback(channel, listener_callback); + } + + self.update_listen_query().await; + Ok(()) } async fn listen(&mut self) -> RustPSQLDriverPyResult<()> { - let Some(client) = self.client.clone() else { + let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::BaseConnectionError("test".into())); }; + let connection = self.connection.clone(); let Some(receiver) = self.receiver.clone() else { return Err(RustPSQLDriverError::BaseConnectionError("test".into())); }; + let listen_query_clone = self.listen_query.clone(); + let is_listened_clone = self.is_listened.clone(); - let jh = tokio_runtime().spawn(async move { - client - .batch_execute( - "LISTEN test_notifications; - LISTEN test_notifications2;", - ) - .await - .unwrap(); + let channel_callbacks = self.channel_callbacks.clone(); + let jh: JoinHandle> = tokio_runtime().spawn(async move { loop { - let mut write_receiver = receiver.write().await; - let next_element = write_receiver.next().await; - client - .batch_execute("LISTEN test_notifications3;") - .await - .unwrap(); - match next_element { - Some(n) => match n { - tokio_postgres::AsyncMessage::Notification(n) => { - println!("Notification {n:?}"); - } - _ => { - println!("in_in {n:?}"); - } - }, - _ => { - println!("in {next_element:?}"); + { + call_listen(&is_listened_clone, &listen_query_clone, &client).await?; + }; + + let next_element = { + let mut write_receiver = receiver.write().await; + write_receiver.next().await + }; + + let inner_notification = process_message(next_element)?; + + let read_channel_callbacks = channel_callbacks.read().await; + let channel = inner_notification.channel.clone(); + let callbacks = read_channel_callbacks.retrieve_channel_callbacks( + channel, + ); + + if let Some(callbacks) = callbacks { + for callback in callbacks { + dispatch_callback( + callback, + inner_notification.clone(), + connection.clone(), + ).await?; } } } @@ -227,3 +402,51 @@ impl Listener { Ok(()) } } + +async fn dispatch_callback( + listener_callback: &ListenerCallback, + listener_notification: ListenerNotification, + connection: Connection, +) -> RustPSQLDriverPyResult<()> { + listener_callback.call( + listener_notification.clone(), + connection, + ).await?; + + Ok(()) +} + +async fn call_listen( + is_listened: &Arc>, + listen_query: &Arc>, + client: &Arc, +) -> RustPSQLDriverPyResult<()> { + let mut write_is_listened = is_listened.write().await; + + if !write_is_listened.eq(&true) { + let listen_q = { + let read_listen_query = listen_query.read().await; + String::from(read_listen_query.as_str()) + }; + + client + .batch_execute(listen_q.as_str()) + .await?; + } + + *write_is_listened = true; + Ok(()) +} + +fn process_message( + message: Option, +) -> RustPSQLDriverPyResult { + let Some(async_message) = message else { + return Err(RustPSQLDriverError::ListenerError("Wow".into())) + }; + let AsyncMessage::Notification(notification) = async_message else { + return Err(RustPSQLDriverError::ListenerError("Wow".into())) + }; + + Ok(ListenerNotification::from(notification)) +} diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index bc992c68..bb429a76 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -1,5 +1,4 @@ use bytes::BytesMut; -use deadpool_postgres::Object; use futures_util::{future, pin_mut}; use pyo3::{ buffer::PyBuffer, @@ -17,8 +16,7 @@ use crate::{ }; use super::{ - cursor::Cursor, - transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, + connection::InnerConnection, cursor::Cursor, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit} }; use crate::common::ObjectQueryTrait; use std::{collections::HashSet, sync::Arc}; @@ -36,7 +34,7 @@ pub trait TransactionObjectTrait { fn rollback(&self) -> impl std::future::Future> + Send; } -impl TransactionObjectTrait for Object { +impl TransactionObjectTrait for InnerConnection { async fn start_transaction( &self, isolation_level: Option, @@ -106,7 +104,7 @@ impl TransactionObjectTrait for Object { #[pyclass(subclass)] pub struct Transaction { - pub db_client: Option>, + pub db_client: Option>, is_started: bool, is_done: bool, @@ -122,7 +120,7 @@ impl Transaction { #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( - db_client: Arc, + db_client: Arc, is_started: bool, is_done: bool, isolation_level: Option, diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 6502c1dc..c95c3e58 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -3,6 +3,7 @@ use std::{str::FromStr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig}; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; +use pyo3::{types::PyAnyMethods, Py, PyAny, Python}; use tokio_postgres::{Config, NoTls}; use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; @@ -214,3 +215,18 @@ pub fn build_tls( mgr } + +pub fn is_coroutine_function( + function: Py, +) -> RustPSQLDriverPyResult { + let is_coroutine_function: bool = Python::with_gil(|py| { + let inspect = py.import_bound("inspect")?; + + let is_cor = inspect.call_method1("iscoroutinefunction", (function,)).map_err(|_| { + RustPSQLDriverError::ListenerClosedError + })?.extract::()?; + Ok::(is_cor) + })?; + + Ok(is_coroutine_function) +} diff --git a/src/exceptions/python_errors.rs b/src/exceptions/python_errors.rs index bd5fc641..c30ea70d 100644 --- a/src/exceptions/python_errors.rs +++ b/src/exceptions/python_errors.rs @@ -105,6 +105,16 @@ create_exception!(psqlpy.exceptions, CursorCloseError, BaseCursorError); create_exception!(psqlpy.exceptions, CursorFetchError, BaseCursorError); create_exception!(psqlpy.exceptions, CursorClosedError, BaseCursorError); +// Listener Error +create_exception!( + psqlpy.exceptions, + BaseListenerError, + RustPSQLDriverPyBaseError +); +create_exception!(psqlpy.exceptions, ListenerStartError, BaseListenerError); +create_exception!(psqlpy.exceptions, ListenerClosedError, BaseListenerError); +create_exception!(psqlpy.exceptions, ListenerCallbackError, BaseListenerError); + // Inner exceptions create_exception!( psqlpy.exceptions, diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 72fb659c..428a605d 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -5,13 +5,7 @@ use tokio::task::JoinError; use crate::exceptions::python_errors::{PyToRustValueMappingError, RustToPyValueMappingError}; use super::python_errors::{ - BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseTransactionError, - ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, - ConnectionPoolConfigurationError, ConnectionPoolExecuteError, CursorCloseError, - CursorClosedError, CursorFetchError, CursorStartError, DriverError, MacAddrParseError, - RuntimeJoinError, SSLError, TransactionBeginError, TransactionClosedError, - TransactionCommitError, TransactionExecuteError, TransactionRollbackError, - TransactionSavepointError, UUIDValueConvertError, + BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseListenerError, BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, ConnectionPoolConfigurationError, ConnectionPoolExecuteError, CursorCloseError, CursorClosedError, CursorFetchError, CursorStartError, DriverError, ListenerCallbackError, ListenerClosedError, ListenerStartError, MacAddrParseError, RuntimeJoinError, SSLError, TransactionBeginError, TransactionClosedError, TransactionCommitError, TransactionExecuteError, TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError }; pub type RustPSQLDriverPyResult = Result; @@ -64,6 +58,16 @@ pub enum RustPSQLDriverError { #[error("Underlying connection is returned to the pool")] CursorClosedError, + // Listener Errors + #[error("Listener error: {0}")] + ListenerError(String), + #[error("Listener start error: {0}")] + ListenerStartError(String), + #[error("Underlying connection is returned to the pool")] + ListenerClosedError, + #[error("Callback must be an async callable")] + ListenerCallbackError, + #[error("Can't convert value from driver to python type: {0}")] RustToPyValueConversionError(String), #[error("Can't convert value from python to rust type: {0}")] @@ -161,6 +165,10 @@ impl From for pyo3::PyErr { RustPSQLDriverError::CursorFetchError(_) => CursorFetchError::new_err((error_desc,)), RustPSQLDriverError::SSLError(_) => SSLError::new_err((error_desc,)), RustPSQLDriverError::CursorClosedError => CursorClosedError::new_err((error_desc,)), + RustPSQLDriverError::ListenerError(_) => BaseListenerError::new_err((error_desc,)), + RustPSQLDriverError::ListenerStartError(_) => ListenerStartError::new_err((error_desc,)), + RustPSQLDriverError::ListenerClosedError => ListenerClosedError::new_err((error_desc,)), + RustPSQLDriverError::ListenerCallbackError => ListenerCallbackError::new_err((error_desc,)), } } } diff --git a/src/lib.rs b/src/lib.rs index a535f4c9..5951ea6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; From 2a14188d3f8129a5132d21b0575028482ce74a52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 18 Jan 2025 22:02:53 +0000 Subject: [PATCH 04/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/psqlpy/_internal/__init__.pyi | 1 - src/driver/connection.rs | 8 ++++---- src/driver/listener.rs | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index c05bfc8e..a0a5563b 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1752,6 +1752,5 @@ class ConnectionPoolBuilder: class Listener: """Result.""" - class ListenerNotification: """Result.""" diff --git a/src/driver/connection.rs b/src/driver/connection.rs index b3add461..fb115e10 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -39,7 +39,7 @@ impl InnerConnection { } } } - + pub async fn query( &self, statement: &T, @@ -71,8 +71,8 @@ impl InnerConnection { &self, statement: &T, params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult - where T: ?Sized + ToStatement + ) -> RustPSQLDriverPyResult + where T: ?Sized + ToStatement { match self { InnerConnection::PoolConn(pconn) => { @@ -87,7 +87,7 @@ impl InnerConnection { pub async fn copy_in( &self, statement: &T - ) -> RustPSQLDriverPyResult> + ) -> RustPSQLDriverPyResult> where T: ?Sized + ToStatement, U: Buf + 'static + Send diff --git a/src/driver/listener.rs b/src/driver/listener.rs index 7f9dadb7..71e39a26 100644 --- a/src/driver/listener.rs +++ b/src/driver/listener.rs @@ -107,7 +107,7 @@ impl ListenerCallback { } (self.callback.clone(), None) }); - + if let Some(task_locals) = task_locals { tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { let future = Python::with_gil(|py| { @@ -222,7 +222,7 @@ impl Listener { return Ok(()); } - + return Err(RustPSQLDriverError::ListenerClosedError) } From 143d33354e045f154e40ee9072d410e2a5f1e065 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 19 Jan 2025 01:39:17 +0100 Subject: [PATCH 05/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/__init__.py | 4 +- python/psqlpy/_internal/__init__.pyi | 73 +++++++++++++- src/driver/connection_pool.rs | 2 +- src/driver/listener.rs | 143 ++++++++++++++++++--------- src/lib.rs | 2 +- 5 files changed, 173 insertions(+), 51 deletions(-) diff --git a/python/psqlpy/__init__.py b/python/psqlpy/__init__.py index edbdb53b..6f899719 100644 --- a/python/psqlpy/__init__.py +++ b/python/psqlpy/__init__.py @@ -7,7 +7,7 @@ IsolationLevel, KeepaliveConfig, Listener, - ListenerNotification, + ListenerNotificationMsg, LoadBalanceHosts, QueryResult, ReadVariant, @@ -28,7 +28,7 @@ "IsolationLevel", "KeepaliveConfig", "Listener", - "ListenerNotification", + "ListenerNotificationMsg", "LoadBalanceHosts", "QueryResult", "ReadVariant", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index c05bfc8e..98ff7ed1 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -2,7 +2,7 @@ import types from enum import Enum from io import BytesIO from ipaddress import IPv4Address, IPv6Address -from typing import Any, Callable, Sequence, TypeVar +from typing import Any, Awaitable, Callable, Sequence, TypeVar from typing_extensions import Buffer, Self @@ -1360,6 +1360,9 @@ class ConnectionPool: res = await connection.execute(...) ``` """ + def listener(self: Self) -> Listener: + """Create new listener.""" + def close(self: Self) -> None: """Close the connection pool.""" @@ -1752,6 +1755,70 @@ class ConnectionPoolBuilder: class Listener: """Result.""" + connection: Connection -class ListenerNotification: - """Result.""" + def __aiter__(self: Self) -> Self: ... + async def __anext__(self: Self) -> ListenerNotificationMsg: ... + async def __aenter__(self: Self) -> Self: ... + async def __aexit__( + self: Self, + exception_type: type[BaseException] | None, + exception: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + async def startup(self: Self) -> None: + """Startup the listener. + + Each listener MUST be started up. + """ + async def add_callback( + self: Self, + channel: str, + callback: Callable[ + [str, str, int, Connection], Awaitable[None], + ], + ) -> None: + """Add callback to the channel. + + Callback must be async function and have signature like this: + ```python + async def callback( + channel: str, + payload: str, + process_id: str, + connection: Connection, + ) -> None: + ... + ``` + """ + + async def clear_channel_callbacks(self, channel: str) -> None: + """Remove all callbacks for the channel. + + ### Parameters: + - `channel`: name of the channel. + """ + + async def listen(self: Self) -> None: + """Start listening. + + Start actual listening. + In the background it creates task in Rust event loop. + You must save returned Future to the array. + """ + + async def abort_listen(self: Self) -> None: + """Abort listen. + + If `listen()` method was called, stop listening, + else don't do anything. + """ + + +class ListenerNotificationMsg: + """Listener message in async iterator.""" + + process_id: int + channel: str + payload: str + connection: Connection diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 6d090c64..fabe789e 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -500,7 +500,7 @@ impl ConnectionPool { Connection::new(None, Some(self.pool.clone())) } - pub async fn add_listener( + pub fn listener( self_: pyo3::Py, ) -> RustPSQLDriverPyResult { let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { diff --git a/src/driver/listener.rs b/src/driver/listener.rs index 7f9dadb7..b7835fe9 100644 --- a/src/driver/listener.rs +++ b/src/driver/listener.rs @@ -55,11 +55,10 @@ impl ChannelCallbacks { #[derive(Clone, Debug)] -#[pyclass] pub struct ListenerNotification { - process_id: i32, - channel: String, - payload: String, + pub process_id: i32, + pub channel: String, + pub payload: String, } impl From:: for ListenerNotification { @@ -73,13 +72,47 @@ impl From:: for ListenerNotification { } #[pyclass] -struct ListenerNotificationMsg { +pub struct ListenerNotificationMsg { process_id: i32, channel: String, payload: String, connection: Connection, } +#[pymethods] +impl ListenerNotificationMsg { + #[getter] + fn process_id(&self) -> i32 { + self.process_id + } + + #[getter] + fn channel(&self) -> String { + self.channel.clone() + } + + #[getter] + fn payload(&self) -> String { + self.payload.clone() + } + + #[getter] + fn connection(&self) -> Connection { + self.connection.clone() + } +} + +impl ListenerNotificationMsg { + fn new(value: ListenerNotification, conn: Connection) -> Self { + ListenerNotificationMsg { + process_id: value.process_id, + channel: String::from(value.channel), + payload: String::from(value.payload), + connection: conn, + } + } +} + struct ListenerCallback { task_locals: Option, callback: Py, @@ -111,7 +144,15 @@ impl ListenerCallback { if let Some(task_locals) = task_locals { tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { let future = Python::with_gil(|py| { - let awaitable = callback.call1(py, (lister_notification, connection)).unwrap(); + let awaitable = callback.call1( + py, + ( + lister_notification.channel, + lister_notification.payload, + lister_notification.process_id, + connection, + ) + ).unwrap(); pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap() }); future.await.unwrap(); @@ -226,6 +267,41 @@ impl Listener { return Err(RustPSQLDriverError::ListenerClosedError) } + fn __anext__(&self) -> RustPSQLDriverPyResult>> { + let Some(client) = self.connection.db_client() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying client, please call startup".into(), + )); + }; + let Some(receiver) = self.receiver.clone() else { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener doesn't have underlying receiver, please call startup".into(), + )); + }; + + let is_listened_clone = self.is_listened.clone(); + let listen_query_clone = self.listen_query.clone(); + let connection = self.connection.clone(); + + let py_future = Python::with_gil(move |gil| { + rustdriver_future(gil, async move { + { + call_listen(&is_listened_clone, &listen_query_clone, &client).await?; + }; + let next_element = { + let mut write_receiver = receiver.write().await; + write_receiver.next().await + }; + + let inner_notification = process_message(next_element)?; + + Ok(ListenerNotificationMsg::new(inner_notification, connection)) + }) + }); + + Ok(Some(py_future?)) + } + #[getter] fn connection(&self) -> Connection { self.connection.clone() @@ -280,40 +356,6 @@ impl Listener { Ok(()) } - fn __anext__(&self) -> RustPSQLDriverPyResult> { - let Some(client) = self.connection.db_client() else { - return Err(RustPSQLDriverError::ListenerStartError( - "Listener doesn't have underlying client, please call startup".into(), - )); - }; - let Some(receiver) = self.receiver.clone() else { - return Err(RustPSQLDriverError::ListenerStartError( - "Listener doesn't have underlying receiver, please call startup".into(), - )); - }; - - let is_listened_clone = self.is_listened.clone(); - let listen_query_clone = self.listen_query.clone(); - - let py_future = Python::with_gil(move |gil| { - rustdriver_future(gil, async move { - { - call_listen(&is_listened_clone, &listen_query_clone, &client).await?; - }; - let next_element = { - let mut write_receiver = receiver.write().await; - write_receiver.next().await - }; - - let inner_notification = process_message(next_element)?; - - Ok(inner_notification) - }) - }); - - Ok(Some(py_future?)) - } - #[pyo3(signature = (channel, callback))] async fn add_callback( &mut self, @@ -337,10 +379,6 @@ impl Listener { callback, ); - // let awaitable = callback.call1(()).unwrap(); - // println!("8888888 {:?}", awaitable); - // let bbb = pyo3_async_runtimes::tokio::into_future(awaitable).unwrap(); - // println!("999999"); { let mut write_channel_callbacks = self.channel_callbacks.write().await; write_channel_callbacks.add_callback(channel, listener_callback); @@ -351,6 +389,15 @@ impl Listener { Ok(()) } + async fn clear_channel_callbacks(&mut self, channel: String) { + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.clear_channel_callbacks(channel); + } + + self.update_listen_query().await; + } + async fn listen(&mut self) -> RustPSQLDriverPyResult<()> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::BaseConnectionError("test".into())); @@ -401,6 +448,14 @@ impl Listener { Ok(()) } + + async fn abort_listen(&mut self) { + if let Some(listen_abort_handler) = &self.listen_abort_handler { + listen_abort_handler.abort(); + } + + self.listen_abort_handler = None; + } } async fn dispatch_callback( diff --git a/src/lib.rs b/src/lib.rs index 5951ea6d..8175546e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; - pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; From 324a7663a7308b8122762da6774db7f799423eb5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Jan 2025 00:59:58 +0000 Subject: [PATCH 06/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/psqlpy/_internal/__init__.pyi | 21 ++++++++++----------- src/driver/listener/core.rs | 2 +- src/driver/listener/structs.rs | 8 ++++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 98ff7ed1..6bfed7e0 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1775,11 +1775,12 @@ class Listener: self: Self, channel: str, callback: Callable[ - [str, str, int, Connection], Awaitable[None], + [str, str, int, Connection], + Awaitable[None], ], ) -> None: """Add callback to the channel. - + Callback must be async function and have signature like this: ```python async def callback( @@ -1787,14 +1788,13 @@ class Listener: payload: str, process_id: str, connection: Connection, - ) -> None: - ... + ) -> None: ... ``` """ - + async def clear_channel_callbacks(self, channel: str) -> None: """Remove all callbacks for the channel. - + ### Parameters: - `channel`: name of the channel. """ @@ -1802,19 +1802,18 @@ class Listener: async def listen(self: Self) -> None: """Start listening. - Start actual listening. - In the background it creates task in Rust event loop. + Start actual listening. + In the background it creates task in Rust event loop. You must save returned Future to the array. """ async def abort_listen(self: Self) -> None: """Abort listen. - - If `listen()` method was called, stop listening, + + If `listen()` method was called, stop listening, else don't do anything. """ - class ListenerNotificationMsg: """Listener message in async iterator.""" diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 3c48547e..1a661cb1 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -115,7 +115,7 @@ impl Listener { return Ok(()); } - + return Err(RustPSQLDriverError::ListenerClosedError) } diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs index 5bd25365..72fcb00f 100644 --- a/src/driver/listener/structs.rs +++ b/src/driver/listener/structs.rs @@ -74,17 +74,17 @@ impl ListenerNotificationMsg { fn process_id(&self) -> i32 { self.process_id } - + #[getter] fn channel(&self) -> String { self.channel.clone() } - + #[getter] fn payload(&self) -> String { self.payload.clone() } - + #[getter] fn connection(&self) -> Connection { self.connection.clone() @@ -129,7 +129,7 @@ impl ListenerCallback { } (self.callback.clone(), None) }); - + if let Some(task_locals) = task_locals { tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { let future = Python::with_gil(|py| { From bcd96f3fa8a4aa88b956db1a36c525f65fd0c260 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 19 Jan 2025 02:01:19 +0100 Subject: [PATCH 07/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/_internal/__init__.pyi | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 98ff7ed1..6bfed7e0 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1775,11 +1775,12 @@ class Listener: self: Self, channel: str, callback: Callable[ - [str, str, int, Connection], Awaitable[None], + [str, str, int, Connection], + Awaitable[None], ], ) -> None: """Add callback to the channel. - + Callback must be async function and have signature like this: ```python async def callback( @@ -1787,14 +1788,13 @@ class Listener: payload: str, process_id: str, connection: Connection, - ) -> None: - ... + ) -> None: ... ``` """ - + async def clear_channel_callbacks(self, channel: str) -> None: """Remove all callbacks for the channel. - + ### Parameters: - `channel`: name of the channel. """ @@ -1802,19 +1802,18 @@ class Listener: async def listen(self: Self) -> None: """Start listening. - Start actual listening. - In the background it creates task in Rust event loop. + Start actual listening. + In the background it creates task in Rust event loop. You must save returned Future to the array. """ async def abort_listen(self: Self) -> None: """Abort listen. - - If `listen()` method was called, stop listening, + + If `listen()` method was called, stop listening, else don't do anything. """ - class ListenerNotificationMsg: """Listener message in async iterator.""" From 9e1118eec65392e1232374d06044ecbb13ef549d Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 19 Jan 2025 12:56:58 +0100 Subject: [PATCH 08/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/_internal/__init__.pyi | 4 +- src/common.rs | 9 +- src/driver/connection.rs | 98 ++++++++++++---------- src/driver/connection_pool.rs | 21 +++-- src/driver/cursor.rs | 8 +- src/driver/listener/core.rs | 121 ++++++++++++--------------- src/driver/listener/mod.rs | 2 +- src/driver/listener/structs.rs | 95 +++++++++++---------- src/driver/transaction.rs | 12 +-- src/driver/utils.rs | 35 ++++---- src/exceptions/rust_errors.rs | 16 +++- 11 files changed, 225 insertions(+), 196 deletions(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 6bfed7e0..e348d771 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1799,7 +1799,7 @@ class Listener: - `channel`: name of the channel. """ - async def listen(self: Self) -> None: + def listen(self: Self) -> None: """Start listening. Start actual listening. @@ -1807,7 +1807,7 @@ class Listener: You must save returned Future to the array. """ - async def abort_listen(self: Self) -> None: + def abort_listen(self: Self) -> None: """Abort listen. If `listen()` method was called, stop listening, diff --git a/src/common.rs b/src/common.rs index f4772860..aa98a6e2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -4,7 +4,10 @@ use pyo3::{ }; use crate::{ - driver::connection::InnerConnection, exceptions::rust_errors::RustPSQLDriverPyResult, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, value_converter::{convert_parameters, PythonDTO, QueryParameter} + driver::connection::PsqlpyConnection, + exceptions::rust_errors::RustPSQLDriverPyResult, + query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, + value_converter::{convert_parameters, PythonDTO, QueryParameter}, }; /// Add new module to the parent one. @@ -52,7 +55,7 @@ pub trait ObjectQueryTrait { ) -> impl std::future::Future> + Send; } -impl ObjectQueryTrait for InnerConnection { +impl ObjectQueryTrait for PsqlpyConnection { async fn psqlpy_query_one( &self, querystring: String, @@ -128,6 +131,6 @@ impl ObjectQueryTrait for InnerConnection { } async fn psqlpy_query_simple(&self, querystring: String) -> RustPSQLDriverPyResult<()> { - Ok(self.batch_execute(querystring.as_str()).await?) + self.batch_execute(querystring.as_str()).await } } diff --git a/src/driver/connection.rs b/src/driver/connection.rs index fb115e10..f9719c8d 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -4,7 +4,9 @@ use futures_util::pin_mut; use postgres_types::ToSql; use pyo3::{buffer::PyBuffer, pyclass, pymethods, Py, PyAny, PyErr, Python}; use std::{collections::HashSet, sync::Arc, vec}; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, Client, CopyInSink, Row, Statement, ToStatement}; +use tokio_postgres::{ + binary_copy::BinaryCopyInWriter, Client, CopyInSink, Row, Statement, ToStatement, +}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, @@ -20,85 +22,89 @@ use super::{ transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; -pub enum InnerConnection { +#[allow(clippy::module_name_repetitions)] +pub enum PsqlpyConnection { PoolConn(Object), SingleConn(Client), } -impl InnerConnection { - pub async fn prepare_cached( - &self, - query: &str - ) -> RustPSQLDriverPyResult { +impl PsqlpyConnection { + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot prepare statement. + pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { match self { - InnerConnection::PoolConn(pconn) => { - return Ok(pconn.prepare_cached(query).await?) - } - InnerConnection::SingleConn(sconn) => { - return Ok(sconn.prepare(query).await?) - } + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), } } + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. pub async fn query( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> RustPSQLDriverPyResult> - where T: ?Sized + ToStatement { + where + T: ?Sized + ToStatement, + { match self { - InnerConnection::PoolConn(pconn) => { - return Ok(pconn.query(statement, params).await?) - } - InnerConnection::SingleConn(sconn) => { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), + PsqlpyConnection::SingleConn(sconn) => { return Ok(sconn.query(statement, params).await?) } } } + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { match self { - InnerConnection::PoolConn(pconn) => { - return Ok(pconn.batch_execute(query).await?) - } - InnerConnection::SingleConn(sconn) => { - return Ok(sconn.batch_execute(query).await?) - } + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), } } + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute statement. pub async fn query_one( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> RustPSQLDriverPyResult - where T: ?Sized + ToStatement + where + T: ?Sized + ToStatement, { match self { - InnerConnection::PoolConn(pconn) => { + PsqlpyConnection::PoolConn(pconn) => { return Ok(pconn.query_one(statement, params).await?) } - InnerConnection::SingleConn(sconn) => { + PsqlpyConnection::SingleConn(sconn) => { return Ok(sconn.query_one(statement, params).await?) } } } - pub async fn copy_in( - &self, - statement: &T - ) -> RustPSQLDriverPyResult> + /// Prepare cached statement. + /// + /// # Errors + /// May return Err if cannot execute copy data. + pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> where T: ?Sized + ToStatement, - U: Buf + 'static + Send + U: Buf + 'static + Send, { match self { - InnerConnection::PoolConn(pconn) => { - return Ok(pconn.copy_in(statement).await?) - } - InnerConnection::SingleConn(sconn) => { - return Ok(sconn.copy_in(statement).await?) - } + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), } } } @@ -106,22 +112,24 @@ impl InnerConnection { #[pyclass(subclass)] #[derive(Clone)] pub struct Connection { - db_client: Option>, + db_client: Option>, db_pool: Option, } impl Connection { #[must_use] - pub fn new(db_client: Option>, db_pool: Option) -> Self { + pub fn new(db_client: Option>, db_pool: Option) -> Self { Connection { db_client, db_pool } } - pub fn db_client(&self) -> Option> { - return self.db_client.clone() + #[must_use] + pub fn db_client(&self) -> Option> { + self.db_client.clone() } + #[must_use] pub fn db_pool(&self) -> Option { - return self.db_pool.clone() + self.db_pool.clone() } } @@ -151,7 +159,7 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(InnerConnection::PoolConn(db_connection))); + self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))); }); return Ok(self_); } @@ -277,7 +285,7 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return Ok(db_client.batch_execute(&querystring).await?); + return db_client.batch_execute(&querystring).await; } Err(RustPSQLDriverError::ConnectionClosedError) diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index a37d15b9..e4e2e953 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -11,7 +11,10 @@ use crate::{ }; use super::{ - common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::{Connection, InnerConnection}, listener::core::Listener, utils::{build_connection_config, build_manager, build_tls} + common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, + connection::{Connection, PsqlpyConnection}, + listener::core::Listener, + utils::{build_connection_config, build_manager, build_tls}, }; /// Make new connection pool. @@ -210,7 +213,8 @@ pub struct ConnectionPool { } impl ConnectionPool { - #[must_use] pub fn build( + #[must_use] + pub fn build( pool: Pool, pg_config: Config, ca_file: Option, @@ -497,9 +501,9 @@ impl ConnectionPool { Connection::new(None, Some(self.pool.clone())) } - pub fn listener( - self_: pyo3::Py, - ) -> RustPSQLDriverPyResult { + #[must_use] + #[allow(clippy::needless_pass_by_value)] + pub fn listener(self_: pyo3::Py) -> Listener { let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { let b_gil = self_.borrow(gil); ( @@ -509,7 +513,7 @@ impl ConnectionPool { ) }); - Ok(Listener::new(pg_config, ca_file, ssl_mode)) + Listener::new(pg_config, ca_file, ssl_mode) } /// Return new single connection. @@ -524,7 +528,10 @@ impl ConnectionPool { }) .await??; - Ok(Connection::new(Some(Arc::new(InnerConnection::PoolConn(db_connection))), None)) + Ok(Connection::new( + Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))), + None, + )) } /// Close connection pool. diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 3284bc19..89fa7428 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -11,7 +11,7 @@ use crate::{ runtime::rustdriver_future, }; -use super::connection::InnerConnection; +use super::connection::PsqlpyConnection; /// Additional implementation for the `Object` type. #[allow(clippy::ref_option)] @@ -28,7 +28,7 @@ trait CursorObjectTrait { async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; } -impl CursorObjectTrait for InnerConnection { +impl CursorObjectTrait for PsqlpyConnection { /// Start the cursor. /// /// Execute `DECLARE` command with parameters. @@ -90,7 +90,7 @@ impl CursorObjectTrait for InnerConnection { #[pyclass(subclass)] pub struct Cursor { - db_transaction: Option>, + db_transaction: Option>, querystring: String, parameters: Option>, cursor_name: String, @@ -104,7 +104,7 @@ pub struct Cursor { impl Cursor { #[must_use] pub fn new( - db_transaction: Arc, + db_transaction: Arc, querystring: String, parameters: Option>, cursor_name: String, diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 3c48547e..6756eadf 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -5,15 +5,25 @@ use futures_channel::mpsc::UnboundedReceiver; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, Python}; -use tokio::{sync::RwLock, task::{AbortHandle, JoinHandle}}; +use tokio::{ + sync::RwLock, + task::{AbortHandle, JoinHandle}, +}; use tokio_postgres::{AsyncMessage, Config}; use crate::{ - driver::{common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, is_coroutine_function, ConfiguredTLS}}, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime} + driver::{ + common_options::SslMode, + connection::{Connection, PsqlpyConnection}, + utils::{build_tls, is_coroutine_function, ConfiguredTLS}, + }, + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + runtime::{rustdriver_future, tokio_runtime}, }; -use super::structs::{ChannelCallbacks, ListenerCallback, ListenerNotification, ListenerNotificationMsg}; - +use super::structs::{ + ChannelCallbacks, ListenerCallback, ListenerNotification, ListenerNotificationMsg, +}; #[pyclass] pub struct Listener { @@ -30,26 +40,23 @@ pub struct Listener { } impl Listener { - #[must_use] pub fn new( - pg_config: Config, - ca_file: Option, - ssl_mode: Option, - ) -> Self { + #[must_use] + pub fn new(pg_config: Config, ca_file: Option, ssl_mode: Option) -> Self { Listener { pg_config, ca_file, ssl_mode, - channel_callbacks: Default::default(), - listen_abort_handler: Default::default(), + channel_callbacks: Arc::default(), + listen_abort_handler: Option::default(), connection: Connection::new(None, None), - receiver: Default::default(), - listen_query: Default::default(), + receiver: Option::default(), + listen_query: Arc::default(), is_listened: Arc::new(RwLock::new(false)), is_started: false, } } - async fn update_listen_query(&self) -> () { + async fn update_listen_query(&self) { let read_channel_callbacks = self.channel_callbacks.read().await; let channels = read_channel_callbacks.retrieve_all_channels(); @@ -57,9 +64,7 @@ impl Listener { let mut final_query: String = String::default(); for channel_name in channels { - final_query.push_str( - format!("LISTEN {};", channel_name).as_str() - ); + final_query.push_str(format!("LISTEN {channel_name};").as_str()); } let mut write_listen_query = self.listen_query.write().await; @@ -82,25 +87,26 @@ impl Listener { slf } + #[allow(clippy::unused_async)] async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { Ok(slf) } + #[allow(clippy::unused_async)] async fn __aexit__<'a>( slf: Py, _exception_type: Py, exception: Py, _traceback: Py, ) -> RustPSQLDriverPyResult<()> { - let (client, is_exception_none, py_err) = - pyo3::Python::with_gil(|gil| { - let self_ = slf.borrow(gil); - ( - self_.connection.db_client(), - exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), - ) - }); + let (client, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { + let self_ = slf.borrow(gil); + ( + self_.connection.db_client(), + exception.is_none(gil), + PyErr::from_value_bound(exception.into_bound(gil)), + ) + }); if client.is_some() { pyo3::Python::with_gil(|gil| { @@ -115,8 +121,8 @@ impl Listener { return Ok(()); } - - return Err(RustPSQLDriverError::ListenerClosedError) + + Err(RustPSQLDriverError::ListenerClosedError) } fn __anext__(&self) -> RustPSQLDriverPyResult>> { @@ -190,18 +196,14 @@ impl Listener { let (transmitter, receiver) = futures_channel::mpsc::unbounded::(); let stream = - stream::poll_fn( - move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e), - ); + stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); let connection = stream.forward(transmitter).map(|r| r.unwrap()); tokio_runtime().spawn(connection); self.receiver = Some(Arc::new(RwLock::new(receiver))); - self.connection = Connection::new( - Some(Arc::new(InnerConnection::SingleConn(client))), - None, - ); + self.connection = + Connection::new(Some(Arc::new(PsqlpyConnection::SingleConn(client))), None); self.is_started = true; @@ -219,17 +221,12 @@ impl Listener { let is_coro = is_coroutine_function(callback_clone)?; if !is_coro { - return Err(RustPSQLDriverError::ListenerCallbackError) + return Err(RustPSQLDriverError::ListenerCallbackError); } - let task_locals = Python::with_gil(|py| { - pyo3_async_runtimes::tokio::get_current_locals(py)} - )?; + let task_locals = Python::with_gil(pyo3_async_runtimes::tokio::get_current_locals)?; - let listener_callback = ListenerCallback::new( - Some(task_locals), - callback, - ); + let listener_callback = ListenerCallback::new(task_locals, callback); { let mut write_channel_callbacks = self.channel_callbacks.write().await; @@ -244,13 +241,13 @@ impl Listener { async fn clear_channel_callbacks(&mut self, channel: String) { { let mut write_channel_callbacks = self.channel_callbacks.write().await; - write_channel_callbacks.clear_channel_callbacks(channel); + write_channel_callbacks.clear_channel_callbacks(&channel); } self.update_listen_query().await; } - async fn listen(&mut self) -> RustPSQLDriverPyResult<()> { + fn listen(&mut self) -> RustPSQLDriverPyResult<()> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::BaseConnectionError("test".into())); }; @@ -278,17 +275,12 @@ impl Listener { let read_channel_callbacks = channel_callbacks.read().await; let channel = inner_notification.channel.clone(); - let callbacks = read_channel_callbacks.retrieve_channel_callbacks( - channel, - ); + let callbacks = read_channel_callbacks.retrieve_channel_callbacks(&channel); if let Some(callbacks) = callbacks { for callback in callbacks { - dispatch_callback( - callback, - inner_notification.clone(), - connection.clone(), - ).await?; + dispatch_callback(callback, inner_notification.clone(), connection.clone()) + .await?; } } } @@ -301,7 +293,7 @@ impl Listener { Ok(()) } - async fn abort_listen(&mut self) { + fn abort_listen(&mut self) { if let Some(listen_abort_handler) = &self.listen_abort_handler { listen_abort_handler.abort(); } @@ -315,10 +307,9 @@ async fn dispatch_callback( listener_notification: ListenerNotification, connection: Connection, ) -> RustPSQLDriverPyResult<()> { - listener_callback.call( - listener_notification.clone(), - connection, - ).await?; + listener_callback + .call(listener_notification.clone(), connection) + .await?; Ok(()) } @@ -326,7 +317,7 @@ async fn dispatch_callback( async fn call_listen( is_listened: &Arc>, listen_query: &Arc>, - client: &Arc, + client: &Arc, ) -> RustPSQLDriverPyResult<()> { let mut write_is_listened = is_listened.write().await; @@ -336,23 +327,19 @@ async fn call_listen( String::from(read_listen_query.as_str()) }; - client - .batch_execute(listen_q.as_str()) - .await?; + client.batch_execute(listen_q.as_str()).await?; } *write_is_listened = true; Ok(()) } -fn process_message( - message: Option, -) -> RustPSQLDriverPyResult { +fn process_message(message: Option) -> RustPSQLDriverPyResult { let Some(async_message) = message else { - return Err(RustPSQLDriverError::ListenerError("Wow".into())) + return Err(RustPSQLDriverError::ListenerError("Wow".into())); }; let AsyncMessage::Notification(notification) = async_message else { - return Err(RustPSQLDriverError::ListenerError("Wow".into())) + return Err(RustPSQLDriverError::ListenerError("Wow".into())); }; Ok(ListenerNotification::from(notification)) diff --git a/src/driver/listener/mod.rs b/src/driver/listener/mod.rs index ad8eb3bc..b67b7b0b 100644 --- a/src/driver/listener/mod.rs +++ b/src/driver/listener/mod.rs @@ -1,2 +1,2 @@ +pub mod core; pub mod structs; -pub mod core; \ No newline at end of file diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs index 5bd25365..d041258c 100644 --- a/src/driver/listener/structs.rs +++ b/src/driver/listener/structs.rs @@ -5,18 +5,14 @@ use pyo3_async_runtimes::TaskLocals; use tokio_postgres::Notification; use crate::{ - driver::connection::Connection, exceptions::rust_errors::RustPSQLDriverPyResult, runtime::tokio_runtime + driver::connection::Connection, + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + runtime::tokio_runtime, }; - +#[derive(Default)] pub struct ChannelCallbacks(HashMap>); -impl Default for ChannelCallbacks { - fn default() -> Self { - ChannelCallbacks(Default::default()) - } -} - impl ChannelCallbacks { pub fn add_callback(&mut self, channel: String, callback: ListenerCallback) { match self.0.entry(channel) { @@ -29,20 +25,21 @@ impl ChannelCallbacks { }; } - pub fn retrieve_channel_callbacks(&self, channel: String) -> Option<&Vec> { - self.0.get(&channel) + #[must_use] + pub fn retrieve_channel_callbacks(&self, channel: &str) -> Option<&Vec> { + self.0.get(channel) } - pub fn clear_channel_callbacks(&mut self, channel: String) { - self.0.remove(&channel); + pub fn clear_channel_callbacks(&mut self, channel: &str) { + self.0.remove(channel); } + #[must_use] pub fn retrieve_all_channels(&self) -> Vec<&String> { self.0.keys().collect::>() } } - #[derive(Clone, Debug)] pub struct ListenerNotification { pub process_id: i32, @@ -50,7 +47,7 @@ pub struct ListenerNotification { pub payload: String, } -impl From:: for ListenerNotification { +impl From for ListenerNotification { fn from(value: Notification) -> Self { ListenerNotification { process_id: value.process_id(), @@ -74,17 +71,17 @@ impl ListenerNotificationMsg { fn process_id(&self) -> i32 { self.process_id } - + #[getter] fn channel(&self) -> String { self.channel.clone() } - + #[getter] fn payload(&self) -> String { self.payload.clone() } - + #[getter] fn connection(&self) -> Connection { self.connection.clone() @@ -92,61 +89,69 @@ impl ListenerNotificationMsg { } impl ListenerNotificationMsg { + #[must_use] pub fn new(value: ListenerNotification, conn: Connection) -> Self { ListenerNotificationMsg { process_id: value.process_id, - channel: String::from(value.channel), - payload: String::from(value.payload), + channel: value.channel, + payload: value.payload, connection: conn, } } } pub struct ListenerCallback { - task_locals: Option, + task_locals: TaskLocals, callback: Py, } impl ListenerCallback { - pub fn new( - task_locals: Option, - callback: Py, - ) -> Self { + #[must_use] + pub fn new(task_locals: TaskLocals, callback: Py) -> Self { ListenerCallback { task_locals, callback, } } + /// Dispatch the callback. + /// + /// # Errors + /// May return Err Result if cannot call python future. pub async fn call( &self, lister_notification: ListenerNotification, connection: Connection, ) -> RustPSQLDriverPyResult<()> { - let (callback, task_locals) = Python::with_gil(|py| { - if let Some(task_locals) = &self.task_locals { - return (self.callback.clone(), Some(task_locals.clone_ref(py))); - } - (self.callback.clone(), None) - }); - - if let Some(task_locals) = task_locals { - tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { + let (callback, task_locals) = + Python::with_gil(|py| (self.callback.clone(), self.task_locals.clone_ref(py))); + + tokio_runtime() + .spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { let future = Python::with_gil(|py| { - let awaitable = callback.call1( - py, - ( - lister_notification.channel, - lister_notification.payload, - lister_notification.process_id, - connection, + let awaitable = callback + .call1( + py, + ( + lister_notification.channel, + lister_notification.payload, + lister_notification.process_id, + connection, + ), ) - ).unwrap(); - pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap() + .map_err(|_| RustPSQLDriverError::ListenerCallbackError)?; + let aba = pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py))?; + Ok(aba) }); - future.await.unwrap(); - })).await?; - }; + Ok::, RustPSQLDriverError>( + future + .map_err(|_: RustPSQLDriverError| { + RustPSQLDriverError::ListenerCallbackError + })? + .await?, + ) + })) + .await??; Ok(()) } diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index bb429a76..b7ffe57d 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -16,7 +16,9 @@ use crate::{ }; use super::{ - connection::InnerConnection, cursor::Cursor, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit} + connection::PsqlpyConnection, + cursor::Cursor, + transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, }; use crate::common::ObjectQueryTrait; use std::{collections::HashSet, sync::Arc}; @@ -34,7 +36,7 @@ pub trait TransactionObjectTrait { fn rollback(&self) -> impl std::future::Future> + Send; } -impl TransactionObjectTrait for InnerConnection { +impl TransactionObjectTrait for PsqlpyConnection { async fn start_transaction( &self, isolation_level: Option, @@ -104,7 +106,7 @@ impl TransactionObjectTrait for InnerConnection { #[pyclass(subclass)] pub struct Transaction { - pub db_client: Option>, + pub db_client: Option>, is_started: bool, is_done: bool, @@ -120,7 +122,7 @@ impl Transaction { #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( - db_client: Arc, + db_client: Arc, is_started: bool, is_done: bool, isolation_level: Option, @@ -353,7 +355,7 @@ impl Transaction { }); is_transaction_ready?; if let Some(db_client) = db_client { - return Ok(db_client.batch_execute(&querystring).await?); + return db_client.batch_execute(&querystring).await; } Err(RustPSQLDriverError::TransactionClosedError) diff --git a/src/driver/utils.rs b/src/driver/utils.rs index c95c3e58..c7f8d10b 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -175,6 +175,10 @@ pub enum ConfiguredTLS { TlsConnector(MakeTlsConnector), } +/// Create TLS. +/// +/// # Errors +/// May return Err Result if cannot create builder. pub fn build_tls( ca_file: &Option, ssl_mode: Option, @@ -198,33 +202,36 @@ pub fn build_tls( Ok(ConfiguredTLS::NoTls) } -#[must_use] pub fn build_manager( +#[must_use] +pub fn build_manager( mgr_config: ManagerConfig, pg_config: Config, configured_tls: ConfiguredTLS, ) -> Manager { - let mgr: Manager; - match configured_tls { - ConfiguredTLS::NoTls => { - mgr = Manager::from_config(pg_config, NoTls, mgr_config); - } + let mgr: Manager = match configured_tls { + ConfiguredTLS::NoTls => Manager::from_config(pg_config, NoTls, mgr_config), ConfiguredTLS::TlsConnector(connector) => { - mgr = Manager::from_config(pg_config, connector, mgr_config); + Manager::from_config(pg_config, connector, mgr_config) } - } + }; mgr } -pub fn is_coroutine_function( - function: Py, -) -> RustPSQLDriverPyResult { +/// Check is python object async or not. +/// +/// # Errors +/// May return Err Result if cannot +/// 1) import inspect +/// 2) extract boolean +pub fn is_coroutine_function(function: Py) -> RustPSQLDriverPyResult { let is_coroutine_function: bool = Python::with_gil(|py| { let inspect = py.import_bound("inspect")?; - let is_cor = inspect.call_method1("iscoroutinefunction", (function,)).map_err(|_| { - RustPSQLDriverError::ListenerClosedError - })?.extract::()?; + let is_cor = inspect + .call_method1("iscoroutinefunction", (function,)) + .map_err(|_| RustPSQLDriverError::ListenerClosedError)? + .extract::()?; Ok::(is_cor) })?; diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 428a605d..48af50cb 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -5,7 +5,13 @@ use tokio::task::JoinError; use crate::exceptions::python_errors::{PyToRustValueMappingError, RustToPyValueMappingError}; use super::python_errors::{ - BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseListenerError, BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, ConnectionPoolConfigurationError, ConnectionPoolExecuteError, CursorCloseError, CursorClosedError, CursorFetchError, CursorStartError, DriverError, ListenerCallbackError, ListenerClosedError, ListenerStartError, MacAddrParseError, RuntimeJoinError, SSLError, TransactionBeginError, TransactionClosedError, TransactionCommitError, TransactionExecuteError, TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError + BaseConnectionError, BaseConnectionPoolError, BaseCursorError, BaseListenerError, + BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, ConnectionPoolBuildError, + ConnectionPoolConfigurationError, ConnectionPoolExecuteError, CursorCloseError, + CursorClosedError, CursorFetchError, CursorStartError, DriverError, ListenerCallbackError, + ListenerClosedError, ListenerStartError, MacAddrParseError, RuntimeJoinError, SSLError, + TransactionBeginError, TransactionClosedError, TransactionCommitError, TransactionExecuteError, + TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError, }; pub type RustPSQLDriverPyResult = Result; @@ -166,9 +172,13 @@ impl From for pyo3::PyErr { RustPSQLDriverError::SSLError(_) => SSLError::new_err((error_desc,)), RustPSQLDriverError::CursorClosedError => CursorClosedError::new_err((error_desc,)), RustPSQLDriverError::ListenerError(_) => BaseListenerError::new_err((error_desc,)), - RustPSQLDriverError::ListenerStartError(_) => ListenerStartError::new_err((error_desc,)), + RustPSQLDriverError::ListenerStartError(_) => { + ListenerStartError::new_err((error_desc,)) + } RustPSQLDriverError::ListenerClosedError => ListenerClosedError::new_err((error_desc,)), - RustPSQLDriverError::ListenerCallbackError => ListenerCallbackError::new_err((error_desc,)), + RustPSQLDriverError::ListenerCallbackError => { + ListenerCallbackError::new_err((error_desc,)) + } } } } From 0aabe17cfe86678d29adcba18d46be6296bc74bc Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 20 Jan 2025 09:59:06 +0100 Subject: [PATCH 09/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/_internal/__init__.pyi | 55 +++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index e348d771..04335995 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1753,7 +1753,60 @@ class ConnectionPoolBuilder: """ class Listener: - """Result.""" + """Listener for LISTEN command. + + Can be used two ways: + 1) As a background task + 2) As an asynchronous iterator + + ## Examples + + ### Background task: + + ```python + async def callback( + channel: str, + payload: str, + process_id: int, + connection: Connection, + ) -> None: ... + async def main(): + pool = ConnectionPool() + + listener = pool.listener() + await listener.add_callback( + channel="test_channel", + callback=callback, + ) + await listener.startup() + + listener.listen() + ``` + + ### Async iterator + ```python + from psqlpy import + + async def msg_processor( + msg: ListenerNotificationMsg, + ) -> None: + ... + + + async def main(): + pool = ConnectionPool() + + listener = pool.listener() + await listener.add_callback( + channel="test_channel", + callback=callback, + ) + await listener.startup() + + for msg in listener: + await msg_processor(msg) + ``` + """ connection: Connection From 740d3bf3c27a69f2ff47345fede562e7314c474f Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 15:29:23 +0100 Subject: [PATCH 10/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- python/psqlpy/_internal/__init__.pyi | 14 +- python/psqlpy/_internal/exceptions.pyi | 12 + python/psqlpy/exceptions.py | 8 + python/tests/conftest.py | 22 ++ python/tests/test_listener.py | 312 +++++++++++++++++++++++++ src/driver/listener/core.rs | 38 ++- src/driver/listener/structs.rs | 8 +- src/exceptions/python_errors.rs | 17 ++ 8 files changed, 412 insertions(+), 19 deletions(-) create mode 100644 python/tests/test_listener.py diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 04335995..c4965e96 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1828,7 +1828,7 @@ class Listener: self: Self, channel: str, callback: Callable[ - [str, str, int, Connection], + [Connection, str, str, int], Awaitable[None], ], ) -> None: @@ -1837,12 +1837,14 @@ class Listener: Callback must be async function and have signature like this: ```python async def callback( - channel: str, - payload: str, - process_id: str, connection: Connection, + payload: str, + channel: str, + process_id: int, ) -> None: ... ``` + + Callback parameters are passed as args. """ async def clear_channel_callbacks(self, channel: str) -> None: @@ -1852,12 +1854,14 @@ class Listener: - `channel`: name of the channel. """ + async def clear_all_channels(self) -> None: + """Clear all channels callbacks.""" + def listen(self: Self) -> None: """Start listening. Start actual listening. In the background it creates task in Rust event loop. - You must save returned Future to the array. """ def abort_listen(self: Self) -> None: diff --git a/python/psqlpy/_internal/exceptions.pyi b/python/psqlpy/_internal/exceptions.pyi index bad0fc8c..a0588e9f 100644 --- a/python/psqlpy/_internal/exceptions.pyi +++ b/python/psqlpy/_internal/exceptions.pyi @@ -79,3 +79,15 @@ class PyToRustValueMappingError(RustPSQLDriverPyBaseError): You can get this exception when executing queries with parameters. So, if there are no parameters for the query, don't handle this error. """ + +class BaseListenerError(RustPSQLDriverPyBaseError): + """Base error for all Listener errors.""" + +class ListenerStartError(BaseListenerError): + """Error if listener start failed.""" + +class ListenerClosedError(BaseListenerError): + """Error if listener manipulated but it's closed.""" + +class ListenerCallbackError(BaseListenerError): + """Error if callback passed to listener isn't a coroutine.""" diff --git a/python/psqlpy/exceptions.py b/python/psqlpy/exceptions.py index c8240b9e..2d981ef3 100644 --- a/python/psqlpy/exceptions.py +++ b/python/psqlpy/exceptions.py @@ -2,6 +2,7 @@ BaseConnectionError, BaseConnectionPoolError, BaseCursorError, + BaseListenerError, BaseTransactionError, ConnectionClosedError, ConnectionExecuteError, @@ -12,6 +13,9 @@ CursorCloseError, CursorFetchError, CursorStartError, + ListenerCallbackError, + ListenerClosedError, + ListenerStartError, MacAddrConversionError, PyToRustValueMappingError, RustPSQLDriverPyBaseError, @@ -29,6 +33,7 @@ "BaseConnectionError", "BaseConnectionPoolError", "BaseCursorError", + "BaseListenerError", "BaseTransactionError", "ConnectionClosedError", "ConnectionExecuteError", @@ -39,6 +44,9 @@ "CursorClosedError", "CursorFetchError", "CursorStartError", + "ListenerCallbackError", + "ListenerClosedError", + "ListenerStartError", "MacAddrConversionError", "PyToRustValueMappingError", "RustPSQLDriverPyBaseError", diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 515ff87d..bfa3f650 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -68,6 +68,11 @@ def table_name() -> str: return random_string() +@pytest.fixture +def listener_table_name() -> str: + return random_string() + + @pytest.fixture def number_database_records() -> int: return random.randint(10, 35) @@ -137,6 +142,23 @@ async def create_default_data_for_tests( ) +@pytest.fixture +async def create_table_for_listener_tests( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> AsyncGenerator[None, None]: + await psql_pool.execute( + f"CREATE TABLE {listener_table_name}" + f"(id SERIAL, payload VARCHAR(255)," + f"channel VARCHAR(255), process_id INT)", + ) + + yield + await psql_pool.execute( + f"DROP TABLE {listener_table_name}", + ) + + @pytest.fixture async def test_cursor( psql_pool: ConnectionPool, diff --git a/python/tests/test_listener.py b/python/tests/test_listener.py new file mode 100644 index 00000000..46722ca1 --- /dev/null +++ b/python/tests/test_listener.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import asyncio +import typing + +import pytest +from psqlpy.exceptions import ListenerStartError + +if typing.TYPE_CHECKING: + from psqlpy import Connection, ConnectionPool, Listener + +pytestmark = pytest.mark.anyio + + +TEST_CHANNEL = "test_channel" +TEST_PAYLOAD = "test_payload" + + +async def construct_listener( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> Listener: + listener = psql_pool.listener() + await listener.add_callback( + channel=TEST_CHANNEL, + callback=construct_insert_callback( + listener_table_name=listener_table_name, + ), + ) + return listener + + +def construct_insert_callback( + listener_table_name: str, +) -> typing.Callable[ + [Connection, str, str, int], + typing.Awaitable[None], +]: + async def callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, + ) -> None: + await connection.execute( + querystring=f"INSERT INTO {listener_table_name} VALUES (1, $1, $2, $3)", + parameters=( + payload, + channel, + process_id, + ), + ) + + return callback + + +async def notify( + psql_pool: ConnectionPool, + channel: str = TEST_CHANNEL, + with_delay: bool = False, +) -> None: + if with_delay: + await asyncio.sleep(0.5) + + await psql_pool.execute(f"NOTIFY {channel}, '{TEST_PAYLOAD}'") + + +async def check_insert_callback( + psql_pool: ConnectionPool, + listener_table_name: str, + is_insert_exist: bool = True, + number_of_data: int = 1, +) -> None: + test_data_seq = ( + await psql_pool.execute( + f"SELECT * FROM {listener_table_name}", + ) + ).result() + + if is_insert_exist: + assert len(test_data_seq) == number_of_data + else: + assert not len(test_data_seq) + return + + data_record = test_data_seq[0] + + assert data_record["payload"] == TEST_PAYLOAD + assert data_record["channel"] == TEST_CHANNEL + + +async def clear_test_table( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + await psql_pool.execute( + f"DELETE FROM {listener_table_name}", + ) + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_listen( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + """Test that single connection can execute queries.""" + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + listener.listen() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_asynciterator( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + + asyncio.create_task( # noqa: RUF006 + notify( + psql_pool=psql_pool, + with_delay=True, + ), + ) + + async for listener_msg in listener: + assert listener_msg.channel == TEST_CHANNEL + assert listener_msg.payload == TEST_PAYLOAD + break + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_abort( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + listener.listen() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + listener.abort_listen() + + await clear_test_table( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + +async def test_listener_start_exc( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + with pytest.raises(expected_exception=ListenerStartError): + listener.listen() + + +async def test_listener_double_start_exc( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.startup() + + with pytest.raises(expected_exception=ListenerStartError): + await listener.startup() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_more_than_one_callback( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + additional_channel = "test_channel_2" + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + await listener.add_callback( + channel=additional_channel, + callback=construct_insert_callback( + listener_table_name=listener_table_name, + ), + ) + await listener.startup() + listener.listen() + + for channel in [TEST_CHANNEL, additional_channel]: + await notify( + psql_pool=psql_pool, + channel=channel, + ) + + await asyncio.sleep(0.5) + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + number_of_data=2, + ) + + query_result = await psql_pool.execute( + querystring=(f"SELECT * FROM {listener_table_name} WHERE channel = $1"), + parameters=(additional_channel,), + ) + + data_result = query_result.result()[0] + + assert data_result["channel"] == additional_channel + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_clear_callbacks( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await listener.startup() + listener.listen() + + await listener.clear_channel_callbacks( + channel=TEST_CHANNEL, + ) + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + listener.abort_listen() + + +@pytest.mark.usefixtures("create_table_for_listener_tests") +async def test_listener_clear_all_callbacks( + psql_pool: ConnectionPool, + listener_table_name: str, +) -> None: + listener = await construct_listener( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + ) + + await listener.startup() + listener.listen() + + await listener.clear_all_channels() + + await notify(psql_pool=psql_pool) + await asyncio.sleep(0.5) + + await check_insert_callback( + psql_pool=psql_pool, + listener_table_name=listener_table_name, + is_insert_exist=False, + ) + + listener.abort_listen() diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 6756eadf..984e8a27 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -144,7 +144,7 @@ impl Listener { let py_future = Python::with_gil(move |gil| { rustdriver_future(gil, async move { { - call_listen(&is_listened_clone, &listen_query_clone, &client).await?; + execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; }; let next_element = { let mut write_receiver = receiver.write().await; @@ -198,7 +198,11 @@ impl Listener { let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); - let connection = stream.forward(transmitter).map(|r| r.unwrap()); + let connection = stream.forward(transmitter).map(|r| { + r.map_err(|_| { + RustPSQLDriverError::ListenerStartError("Cannot startup the listener".into()) + }) + }); tokio_runtime().spawn(connection); self.receiver = Some(Arc::new(RwLock::new(receiver))); @@ -216,11 +220,7 @@ impl Listener { channel: String, callback: Py, ) -> RustPSQLDriverPyResult<()> { - let callback_clone = callback.clone(); - - let is_coro = is_coroutine_function(callback_clone)?; - - if !is_coro { + if !is_coroutine_function(callback.clone())? { return Err(RustPSQLDriverError::ListenerCallbackError); } @@ -247,14 +247,28 @@ impl Listener { self.update_listen_query().await; } + async fn clear_all_channels(&mut self) { + { + let mut write_channel_callbacks = self.channel_callbacks.write().await; + write_channel_callbacks.clear_all(); + } + + self.update_listen_query().await; + } + fn listen(&mut self) -> RustPSQLDriverPyResult<()> { let Some(client) = self.connection.db_client() else { - return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + return Err(RustPSQLDriverError::ListenerStartError( + "Cannot start listening, underlying connection doesn't exist".into(), + )); }; - let connection = self.connection.clone(); let Some(receiver) = self.receiver.clone() else { - return Err(RustPSQLDriverError::BaseConnectionError("test".into())); + return Err(RustPSQLDriverError::ListenerStartError( + "Cannot start listening, underlying connection doesn't exist".into(), + )); }; + + let connection = self.connection.clone(); let listen_query_clone = self.listen_query.clone(); let is_listened_clone = self.is_listened.clone(); @@ -263,7 +277,7 @@ impl Listener { let jh: JoinHandle> = tokio_runtime().spawn(async move { loop { { - call_listen(&is_listened_clone, &listen_query_clone, &client).await?; + execute_listen(&is_listened_clone, &listen_query_clone, &client).await?; }; let next_element = { @@ -314,7 +328,7 @@ async fn dispatch_callback( Ok(()) } -async fn call_listen( +async fn execute_listen( is_listened: &Arc>, listen_query: &Arc>, client: &Arc, diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs index d041258c..4d53a408 100644 --- a/src/driver/listener/structs.rs +++ b/src/driver/listener/structs.rs @@ -34,6 +34,10 @@ impl ChannelCallbacks { self.0.remove(channel); } + pub fn clear_all(&mut self) { + self.0.clear(); + } + #[must_use] pub fn retrieve_all_channels(&self) -> Vec<&String> { self.0.keys().collect::>() @@ -133,10 +137,10 @@ impl ListenerCallback { .call1( py, ( - lister_notification.channel, + connection, lister_notification.payload, + lister_notification.channel, lister_notification.process_id, - connection, ), ) .map_err(|_| RustPSQLDriverError::ListenerCallbackError)?; diff --git a/src/exceptions/python_errors.rs b/src/exceptions/python_errors.rs index c30ea70d..2f30bf1f 100644 --- a/src/exceptions/python_errors.rs +++ b/src/exceptions/python_errors.rs @@ -142,6 +142,7 @@ create_exception!( create_exception!(psqlpy.exceptions, SSLError, RustPSQLDriverPyBaseError); #[allow(clippy::missing_errors_doc)] +#[allow(clippy::too_many_lines)] pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add( "RustPSQLDriverPyBaseError", @@ -232,5 +233,21 @@ pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> "MacAddrConversionError", py.get_type_bound::(), )?; + pymod.add( + "BaseListenerError", + py.get_type_bound::(), + )?; + pymod.add( + "ListenerStartError", + py.get_type_bound::(), + )?; + pymod.add( + "ListenerClosedError", + py.get_type_bound::(), + )?; + pymod.add( + "ListenerCallbackError", + py.get_type_bound::(), + )?; Ok(()) } From 007296d1ace948b27bf39239ed6a70b46ffb5091 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 16:08:09 +0100 Subject: [PATCH 11/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- .pre-commit-config.yaml | 2 - Cargo.lock | 24 +++++------ Cargo.toml | 11 +---- src/common.rs | 4 +- src/driver/connection.rs | 2 +- src/driver/cursor.rs | 2 +- src/driver/listener/core.rs | 2 +- src/driver/transaction.rs | 2 +- src/driver/utils.rs | 2 +- src/exceptions/python_errors.rs | 72 +++++++++++++-------------------- src/query_result.rs | 10 ++--- src/runtime.rs | 4 +- 12 files changed, 57 insertions(+), 80 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7bd77650..bef1e2c0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,8 +58,6 @@ repos: - clippy::all - -W - clippy::pedantic - - -D - - warnings - id: check types: diff --git a/Cargo.lock b/Cargo.lock index ad9abd60..a63c21a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1052,9 +1052,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +checksum = "57fe09249128b3173d092de9523eaa75136bf7ba85e0d69eca241c7939c933cc" dependencies = [ "cfg-if", "chrono", @@ -1072,8 +1072,8 @@ dependencies = [ [[package]] name = "pyo3-async-runtimes" -version = "0.21.0" -source = "git+https://github.com/chandr-andr/pyo3-async-runtimes.git?branch=main#284bd36d0426a988026f878cae22abdb179795e6" +version = "0.23.0" +source = "git+https://github.com/chandr-andr/pyo3-async-runtimes.git?branch=psqlpy#c2b8441b4910b0b5100536b23c7a2fd43f9eacd0" dependencies = [ "futures", "once_cell", @@ -1084,9 +1084,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +checksum = "1cd3927b5a78757a0d71aa9dff669f903b1eb64b54142a9bd9f757f8fde65fd7" dependencies = [ "once_cell", "target-lexicon", @@ -1094,9 +1094,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +checksum = "dab6bb2102bd8f991e7749f130a70d05dd557613e39ed2deeee8e9ca0c4d548d" dependencies = [ "libc", "pyo3-build-config", @@ -1104,9 +1104,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +checksum = "91871864b353fd5ffcb3f91f2f703a22a9797c91b9ab497b1acac7b07ae509c7" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1116,9 +1116,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.5" +version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +checksum = "43abc3b80bc20f3facd86cd3c60beed58c3e2aa26213f3cda368de39c60a27e4" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 6890e1b3..4e2d8977 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,15 +10,8 @@ crate-type = ["cdylib"] [dependencies] deadpool-postgres = { git = "https://github.com/chandr-andr/deadpool.git", branch = "psqlpy" } -pyo3 = { version = "*", features = [ - "chrono", - "experimental-async", - "rust_decimal", - "py-clone", - "gil-refs", - "macros", -] } -pyo3-async-runtimes = { git = "https://github.com/chandr-andr/pyo3-async-runtimes.git", branch = "main", features = [ +pyo3 = { version = "0.23.4", features = ["chrono", "experimental-async", "rust_decimal", "py-clone", "macros"] } +pyo3-async-runtimes = { git = "https://github.com/chandr-andr/pyo3-async-runtimes.git", branch = "psqlpy", features = [ "tokio-runtime", ] } diff --git a/src/common.rs b/src/common.rs index aa98a6e2..8dc70fc3 100644 --- a/src/common.rs +++ b/src/common.rs @@ -24,10 +24,10 @@ pub fn add_module( child_mod_name: &'static str, child_mod_builder: impl FnOnce(Python<'_>, &Bound<'_, PyModule>) -> PyResult<()>, ) -> PyResult<()> { - let sub_module = PyModule::new_bound(py, child_mod_name)?; + let sub_module = PyModule::new(py, child_mod_name)?; child_mod_builder(py, &sub_module)?; parent_mod.add_submodule(&sub_module)?; - py.import_bound("sys")?.getattr("modules")?.set_item( + py.import("sys")?.getattr("modules")?.set_item( format!("{}.{}", parent_mod.name()?, child_mod_name), sub_module, )?; diff --git a/src/driver/connection.rs b/src/driver/connection.rs index f9719c8d..f10328c2 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -177,7 +177,7 @@ impl Connection { let (is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { ( exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), ) }); diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 89fa7428..7368d29a 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -179,7 +179,7 @@ impl Cursor { self_.closed, self_.cursor_name.clone(), exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), ) }); diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 984e8a27..3c5838ea 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -104,7 +104,7 @@ impl Listener { ( self_.connection.db_client(), exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), ) }); diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index b7ffe57d..3fa59e4d 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -236,7 +236,7 @@ impl Transaction { ( self_.check_is_transaction_ready(), exception.is_none(gil), - PyErr::from_value_bound(exception.into_bound(gil)), + PyErr::from_value(exception.into_bound(gil)), self_.db_client.clone(), ) }); diff --git a/src/driver/utils.rs b/src/driver/utils.rs index c7f8d10b..635b17ef 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -226,7 +226,7 @@ pub fn build_manager( /// 2) extract boolean pub fn is_coroutine_function(function: Py) -> RustPSQLDriverPyResult { let is_coroutine_function: bool = Python::with_gil(|py| { - let inspect = py.import_bound("inspect")?; + let inspect = py.import("inspect")?; let is_cor = inspect .call_method1("iscoroutinefunction", (function,)) diff --git a/src/exceptions/python_errors.rs b/src/exceptions/python_errors.rs index 2f30bf1f..4d3798cb 100644 --- a/src/exceptions/python_errors.rs +++ b/src/exceptions/python_errors.rs @@ -3,6 +3,7 @@ use pyo3::{ types::{PyModule, PyModuleMethods}, Bound, PyResult, Python, }; + // Main exception. create_exception!( psqlpy.exceptions, @@ -146,108 +147,93 @@ create_exception!(psqlpy.exceptions, SSLError, RustPSQLDriverPyBaseError); pub fn python_exceptions_module(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add( "RustPSQLDriverPyBaseError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "BaseConnectionPoolError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolBuildError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolConfigurationError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionPoolExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; - pymod.add( - "BaseConnectionError", - py.get_type_bound::(), - )?; + pymod.add("BaseConnectionError", py.get_type::())?; pymod.add( "ConnectionExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "ConnectionClosedError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "BaseTransactionError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionBeginError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionCommitError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionRollbackError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionSavepointError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionExecuteError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "TransactionClosedError", - py.get_type_bound::(), + py.get_type::(), )?; - pymod.add("BaseCursorError", py.get_type_bound::())?; - pymod.add("CursorStartError", py.get_type_bound::())?; - pymod.add("CursorCloseError", py.get_type_bound::())?; - pymod.add("CursorFetchError", py.get_type_bound::())?; - pymod.add( - "CursorClosedError", - py.get_type_bound::(), - )?; + pymod.add("BaseCursorError", py.get_type::())?; + pymod.add("CursorStartError", py.get_type::())?; + pymod.add("CursorCloseError", py.get_type::())?; + pymod.add("CursorFetchError", py.get_type::())?; + pymod.add("CursorClosedError", py.get_type::())?; pymod.add( "RustToPyValueMappingError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "PyToRustValueMappingError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "UUIDValueConvertError", - py.get_type_bound::(), + py.get_type::(), )?; pymod.add( "MacAddrConversionError", - py.get_type_bound::(), - )?; - pymod.add( - "BaseListenerError", - py.get_type_bound::(), - )?; - pymod.add( - "ListenerStartError", - py.get_type_bound::(), - )?; - pymod.add( - "ListenerClosedError", - py.get_type_bound::(), + py.get_type::(), )?; + pymod.add("BaseListenerError", py.get_type::())?; + pymod.add("ListenerStartError", py.get_type::())?; + pymod.add("ListenerClosedError", py.get_type::())?; pymod.add( "ListenerCallbackError", - py.get_type_bound::(), + py.get_type::(), )?; Ok(()) } diff --git a/src/query_result.rs b/src/query_result.rs index 06299b86..162be3b5 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -16,7 +16,7 @@ fn row_to_dict<'a>( postgres_row: &'a Row, custom_decoders: &Option>, ) -> RustPSQLDriverPyResult> { - let python_dict = PyDict::new_bound(py); + let python_dict = PyDict::new(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; python_dict.set_item(column.name().to_object(py), python_type)?; @@ -85,7 +85,7 @@ impl PSQLDriverPyQueryResult { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?; - let convert_class_inst = as_class.call_bound(py, (), Some(&pydict))?; + let convert_class_inst = as_class.call(py, (), Some(&pydict))?; res.push(convert_class_inst); } @@ -109,7 +109,7 @@ impl PSQLDriverPyQueryResult { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &custom_decoders)?; - let row_factory_class = row_factory.call_bound(py, (pydict,), None)?; + let row_factory_class = row_factory.call(py, (pydict,), None)?; res.push(row_factory_class); } Ok(res.to_object(py)) @@ -170,7 +170,7 @@ impl PSQLDriverSinglePyQueryResult { as_class: Py, ) -> RustPSQLDriverPyResult> { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?; - Ok(as_class.call_bound(py, (), Some(&pydict))?) + Ok(as_class.call(py, (), Some(&pydict))?) } /// Convert result from database with function passed from Python. @@ -188,6 +188,6 @@ impl PSQLDriverSinglePyQueryResult { custom_decoders: Option>, ) -> RustPSQLDriverPyResult> { let pydict = row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py); - Ok(row_factory.call_bound(py, (pydict,), None)?) + Ok(row_factory.call(py, (pydict,), None)?) } } diff --git a/src/runtime.rs b/src/runtime.rs index 21365d4f..05889d99 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,5 +1,5 @@ use futures_util::Future; -use pyo3::{IntoPy, Py, PyAny, PyObject, Python}; +use pyo3::{IntoPyObject, Py, PyAny, Python}; use crate::exceptions::rust_errors::RustPSQLDriverPyResult; @@ -21,7 +21,7 @@ pub fn tokio_runtime() -> &'static tokio::runtime::Runtime { pub fn rustdriver_future(py: Python<'_>, future: F) -> RustPSQLDriverPyResult> where F: Future> + Send + 'static, - T: IntoPy, + T: for<'py> IntoPyObject<'py>, { let res = pyo3_async_runtimes::tokio::future_into_py(py, async { future.await.map_err(Into::into) }) From df348b117a0944fe178270aa43834f4391549005 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 16:08:37 +0100 Subject: [PATCH 12/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7038d412..76269460 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,7 +30,7 @@ jobs: - uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic -D warnings + args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic pytest: name: ${{matrix.job.os}}-${{matrix.py_version}} strategy: From a08e81673a939e5eb81f3fb7ac42507a419a5060 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 18:37:51 +0100 Subject: [PATCH 13/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- docs/.vuepress/sidebar.ts | 1 + docs/components/connection_pool.md | 12 ++ docs/components/listener.md | 204 ++++++++++++++++++++++++++ python/psqlpy/_internal/__init__.pyi | 8 +- src/driver/connection_pool.rs | 2 +- src/driver/connection_pool_builder.rs | 2 +- src/driver/listener/core.rs | 25 +++- src/driver/utils.rs | 4 +- 8 files changed, 250 insertions(+), 8 deletions(-) create mode 100644 docs/components/listener.md diff --git a/docs/.vuepress/sidebar.ts b/docs/.vuepress/sidebar.ts index 81da5e8f..133833f7 100644 --- a/docs/.vuepress/sidebar.ts +++ b/docs/.vuepress/sidebar.ts @@ -23,6 +23,7 @@ export default sidebar({ "connection", "transaction", "cursor", + "listener", "results", "exceptions", ], diff --git a/docs/components/connection_pool.md b/docs/components/connection_pool.md index 15369612..bcee4893 100644 --- a/docs/components/connection_pool.md +++ b/docs/components/connection_pool.md @@ -254,6 +254,18 @@ This is the preferable way to work with the PostgreSQL. ::: +### Listener + +Create a new instance of a listener. + +```python +async def main() -> None: + ... + listener = db_pool.listener() +``` +``` + + ### Close To close the connection pool at the stop of your application. diff --git a/docs/components/listener.md b/docs/components/listener.md new file mode 100644 index 00000000..06dd0955 --- /dev/null +++ b/docs/components/listener.md @@ -0,0 +1,204 @@ +--- +title: Listener +--- + +`Listener` object allows users to work with [LISTEN](https://www.postgresql.org/docs/current/sql-listen.html)/[NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) functionality. + +## Usage + +There two ways of using `Listener` object: +- Async iterator +- Background task + +::: tabs + +@tab Background task +```python +from psqlpy import ConnectionPool, Connection, Listener + + +db_pool = ConnectionPool( + dsn="postgres://postgres:postgres@localhost:5432/postgres", +) + +async def test_channel_callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + # do some important staff + ... + +async def main() -> None: + # Create listener object + listener: Listener = db_pool.listener() + + # Add channel to listen and callback for it. + await listener.add_callback( + channel="test_channel", + callback=test_channel_callback, + ) + + # Startup the listener + await listener.startup() + + # Start listening. + # `listen` method isn't blocking, it returns None and starts background + # task in the Rust event loop. + listener.listen() + + # You can stop listening. + listener.abort_listen() +``` + +@tab Async Iterator +```python +from psqlpy import ( + ConnectionPool, + Connection, + Listener, + ListenerNotificationMsg, +) + + +db_pool = ConnectionPool( + dsn="postgres://postgres:postgres@localhost:5432/postgres", +) + +async def main() -> None: + # Create listener object + listener: Listener = db_pool.listener() + + # Startup the listener + await listener.startup() + + listener_msg: ListenerNotificationMsg + async for listener_msg in listener: + print(listener_msg) +``` + +::: + +## Listener attributes + +- `connection`: Instance of `Connection`. +If `startup` wasn't called, raises `ListenerStartError`. + +- `is_started`: Flag that shows whether the `Listener` is running or not. + +## Listener methods + +### Startup + +Startup `Listener` instance and can be called once or again only after `shutdown`. + +::: important +`Listener` must be started up. +::: + +```python +async def main() -> None: + listener: Listener = db_pool.listener() + + await listener.startup() +``` + +### Shutdown +Abort listen (if called) and release underlying connection. + +```python +async def main() -> None: + listener: Listener = db_pool.listener() + + await listener.startup() + await listener.shutdown() +``` + +### Add Callback + +#### Parameters: +- `channel`: name of the channel to listen. +- `callback`: coroutine callback. + +Add new callback to the channel, can be called more than 1 times. + +Callback signature is like this: +```python +from psqlpy import Connection + +async def callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + ... +``` + +Parameters for callback are based like `args`, so this signature is correct to: +```python +async def callback( + connection: Connection, + *args, +) -> None: + ... +``` + +**Example:** +```python +async def test_channel_callback( + connection: Connection, + payload: str, + channel: str, + process_id: int, +) -> None: + ... + +async def main() -> None: + listener = db_pool.listener() + + await listener.add_callback( + channel="test_channel", + callback=test_channel_callback, + ) +``` + +### Clear Channel Callbacks + +#### Parameters: +- `channel`: name of the channel + +Remove all callbacks for the channel + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.clear_channel_callbacks() +``` + +### Clear All Channels +Clear all channels and callbacks. + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.clear_all_channels() +``` + +### Listen +Start listening. + +It's a non-blocking operation. +In the background it creates task in Rust event loop. + +```python +async def main() -> None: + listener = db_pool.listener() + await listener.startup() + await listener.listen() +``` + +### Abort Listen +Abort listen. +If `listen()` method was called, stop listening, else don't do anything. diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index c4965e96..42b836b2 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -1809,6 +1809,7 @@ class Listener: """ connection: Connection + is_started: bool def __aiter__(self: Self) -> Self: ... async def __anext__(self: Self) -> ListenerNotificationMsg: ... @@ -1824,6 +1825,11 @@ class Listener: Each listener MUST be started up. """ + async def shutdown(self: Self) -> None: + """Shutdown the listener. + + Abort listen and release underlying connection. + """ async def add_callback( self: Self, channel: str, @@ -1844,7 +1850,7 @@ class Listener: ) -> None: ... ``` - Callback parameters are passed as args. + Callback parameters are passed as args on the Rust side. """ async def clear_channel_callbacks(self, channel: str) -> None: diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index e4e2e953..4f38407d 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -127,7 +127,7 @@ pub fn connect( let mgr: Manager = build_manager( mgr_config, pg_config.clone(), - build_tls(&ca_file, ssl_mode)?, + build_tls(&ca_file, &ssl_mode)?, ); let mut db_pool_builder = Pool::builder(mgr); diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index 963da555..e0610942 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -53,7 +53,7 @@ impl ConnectionPoolBuilder { let mgr: Manager = build_manager( mgr_config, self.config.clone(), - build_tls(&self.ca_file, self.ssl_mode)?, + build_tls(&self.ca_file, &self.ssl_mode)?, ); let mut db_pool_builder = Pool::builder(mgr); diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 3c5838ea..c8fd271c 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -161,8 +161,19 @@ impl Listener { } #[getter] - fn connection(&self) -> Connection { - self.connection.clone() + fn is_started(&self) -> bool { + self.is_started + } + + #[getter] + fn connection(&self) -> RustPSQLDriverPyResult { + if !self.is_started { + return Err(RustPSQLDriverError::ListenerStartError( + "Listener isn't started up".into(), + )); + } + + Ok(self.connection.clone()) } async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { @@ -172,7 +183,7 @@ impl Listener { )); } - let tls_ = build_tls(&self.ca_file.clone(), self.ssl_mode)?; + let tls_ = build_tls(&self.ca_file, &self.ssl_mode)?; let mut builder = SslConnector::builder(SslMethod::tls())?; builder.set_verify(SslVerifyMode::NONE); @@ -214,6 +225,14 @@ impl Listener { Ok(()) } + async fn shutdown(&mut self) { + self.abort_listen(); + std::mem::take(&mut self.connection); + std::mem::take(&mut self.receiver); + + self.is_started = false; + } + #[pyo3(signature = (channel, callback))] async fn add_callback( &mut self, diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 635b17ef..3d0d59e3 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -181,7 +181,7 @@ pub enum ConfiguredTLS { /// May return Err Result if cannot create builder. pub fn build_tls( ca_file: &Option, - ssl_mode: Option, + ssl_mode: &Option, ) -> RustPSQLDriverPyResult { if let Some(ca_file) = ca_file { let mut builder = SslConnector::builder(SslMethod::tls())?; @@ -190,7 +190,7 @@ pub fn build_tls( builder.build(), ))); } else if let Some(ssl_mode) = ssl_mode { - if ssl_mode == common_options::SslMode::Require { + if *ssl_mode == common_options::SslMode::Require { let mut builder = SslConnector::builder(SslMethod::tls())?; builder.set_verify(SslVerifyMode::NONE); return Ok(ConfiguredTLS::TlsConnector(MakeTlsConnector::new( From b629c4c98494d7c4d292ef921b3335cff4119f4f Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 18:40:11 +0100 Subject: [PATCH 14/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- docs/components/components_overview.md | 1 + docs/components/listener.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/components/components_overview.md b/docs/components/components_overview.md index 30cf84e3..90b05b70 100644 --- a/docs/components/components_overview.md +++ b/docs/components/components_overview.md @@ -8,6 +8,7 @@ title: Components - `Connection`: represents single database connection, can be retrieved from `ConnectionPool`. - `Transaction`: represents database transaction, can be made from `Connection`. - `Cursor`: represents database cursor, can be made from `Transaction`. +- `Listener`: object to work with [LISTEN](https://www.postgresql.org/docs/current/sql-listen.html)/[NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) functionality, can be mode from `ConnectionPool`. - `QueryResult`: represents list of results from database. - `SingleQueryResult`: represents single result from the database. - `Exceptions`: we have some custom exceptions. diff --git a/docs/components/listener.md b/docs/components/listener.md index 06dd0955..f2ecff30 100644 --- a/docs/components/listener.md +++ b/docs/components/listener.md @@ -6,7 +6,7 @@ title: Listener ## Usage -There two ways of using `Listener` object: +There are two ways of using `Listener` object: - Async iterator - Background task From 01673639b5799febbf9b74421d8ff559a9f69ee5 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 18:50:26 +0100 Subject: [PATCH 15/17] LISTEN/NOTIFY funcionality Signed-off-by: chandr-andr (Kiselev Aleksandr) --- docs/components/exceptions.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/components/exceptions.md b/docs/components/exceptions.md index aac3ecbd..fd2d86fb 100644 --- a/docs/components/exceptions.md +++ b/docs/components/exceptions.md @@ -15,6 +15,7 @@ stateDiagram-v2 RustPSQLDriverPyBaseError --> BaseConnectionError RustPSQLDriverPyBaseError --> BaseTransactionError RustPSQLDriverPyBaseError --> BaseCursorError + RustPSQLDriverPyBaseError --> BaseListenerError RustPSQLDriverPyBaseError --> RustException RustPSQLDriverPyBaseError --> RustToPyValueMappingError RustPSQLDriverPyBaseError --> PyToRustValueMappingError @@ -44,6 +45,11 @@ stateDiagram-v2 [*] --> CursorFetchError [*] --> CursorClosedError } + state BaseListenerError { + [*] --> ListenerStartError + [*] --> ListenerClosedError + [*] --> ListenerCallbackError + } state RustException { [*] --> DriverError [*] --> MacAddrParseError @@ -127,3 +133,15 @@ Error in cursor fetch (any fetch). #### CursorClosedError Error if underlying connection is closed. + +### BaseListenerError +Base error for all Listener errors. + +#### ListenerStartError +Error if listener start failed. + +#### ListenerClosedError +Error if listener manipulated but it's closed + +#### ListenerCallbackError +Error if callback passed to listener isn't a coroutine From 11ca2708546846bc0284cb652d6b237e14bcc673 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 19:17:18 +0100 Subject: [PATCH 16/17] Bumped version to 0.9.0 Signed-off-by: chandr-andr (Kiselev Aleksandr) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a63c21a3..abf26770 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -997,7 +997,7 @@ dependencies = [ [[package]] name = "psqlpy" -version = "0.8.7" +version = "0.9.0" dependencies = [ "byteorder", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 4e2d8977..5e7743de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "psqlpy" -version = "0.8.7" +version = "0.9.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 026266303e56396b7760da796a666740fea818c8 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 25 Jan 2025 19:53:56 +0100 Subject: [PATCH 17/17] Bumped version to 0.9.0 Signed-off-by: chandr-andr (Kiselev Aleksandr) --- .github/workflows/release.yml | 8 ++++---- .github/workflows/test.yaml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f6161fc4..077354e8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -25,7 +25,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 sccache: 'true' manylinux: auto before-script-linux: | @@ -70,7 +70,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v3 @@ -110,7 +110,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v3 @@ -164,7 +164,7 @@ jobs: uses: messense/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist -i 3.8 3.9 3.10 3.11 3.12 3.13 pypy3.8 pypy3.9 pypy3.10 + args: --release --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 manylinux: musllinux_1_2 - name: Upload wheels uses: actions/upload-artifact@v3 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 76269460..7d81fbd3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -35,7 +35,7 @@ jobs: name: ${{matrix.job.os}}-${{matrix.py_version}} strategy: matrix: - py_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl