Skip to content

Commit 4008c0d

Browse files
committed
Added context manager support to ConnectionPool
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent a3e24a1 commit 4008c0d

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

python/psqlpy/_internal/__init__.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,14 @@ class ConnectionPool:
12151215
- `ca_file`: Loads trusted root certificates from a file.
12161216
The file should contain a sequence of PEM-formatted CA certificates.
12171217
"""
1218+
def __iter__(self: Self) -> Self: ...
1219+
def __enter__(self: Self) -> Self: ...
1220+
def __exit__(
1221+
self: Self,
1222+
exception_type: type[BaseException] | None,
1223+
exception: BaseException | None,
1224+
traceback: types.TracebackType | None,
1225+
) -> None: ...
12181226
def status(self: Self) -> ConnectionPoolStatus:
12191227
"""Return information about connection pool.
12201228

python/tests/test_connection_pool.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,15 @@ async def test_close_connection_pool() -> None:
175175

176176
with pytest.raises(expected_exception=RustPSQLDriverPyBaseError):
177177
await pg_pool.execute("SELECT 1")
178+
179+
180+
async def test_connection_pool_as_context_manager() -> None:
181+
"""Test connection pool as context manager."""
182+
with ConnectionPool(
183+
dsn="postgres://postgres:postgres@localhost:5432/psqlpy_test",
184+
) as pg_pool:
185+
res = await pg_pool.execute("SELECT 1")
186+
assert res.result()
187+
188+
with pytest.raises(expected_exception=RustPSQLDriverPyBaseError):
189+
await pg_pool.execute("SELECT 1")

src/driver/connection_pool.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::runtime::tokio_runtime;
22
use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod};
33
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
44
use postgres_openssl::MakeTlsConnector;
5-
use pyo3::{pyclass, pyfunction, pymethods, PyAny};
5+
use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny};
66
use std::{sync::Arc, vec};
77
use tokio_postgres::NoTls;
88

@@ -253,6 +253,28 @@ impl ConnectionPool {
253253
)
254254
}
255255

256+
#[must_use]
257+
pub fn __iter__(self_: Py<Self>) -> Py<Self> {
258+
self_
259+
}
260+
261+
#[allow(clippy::needless_pass_by_value)]
262+
fn __enter__(self_: Py<Self>) -> Py<Self> {
263+
self_
264+
}
265+
266+
#[allow(clippy::needless_pass_by_value)]
267+
fn __exit__(
268+
self_: Py<Self>,
269+
_exception_type: Py<PyAny>,
270+
_exception: Py<PyAny>,
271+
_traceback: Py<PyAny>,
272+
) {
273+
pyo3::Python::with_gil(|gil| {
274+
self_.borrow(gil).close();
275+
});
276+
}
277+
256278
#[must_use]
257279
pub fn status(&self) -> ConnectionPoolStatus {
258280
let inner_status = self.0.status();

0 commit comments

Comments
 (0)