Skip to content

Commit 696a23d

Browse files
authored
Add context manager support for database connections (#96)
Implements __enter__ and __exit__ methods to enable using Connection objects as context managers. On clean exit, transactions are automatically committed. On exception, transactions are automatically rolled back. This provides sqlite3-compatible behavior and safer transaction handling. Closes #95
2 parents d76ef4d + 68dd923 commit 696a23d

File tree

3 files changed

+269
-27
lines changed

3 files changed

+269
-27
lines changed

docs/api.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ Rolls back the current transaction and starts a new one.
3232

3333
Closes the database connection.
3434

35+
### `with` statement
36+
37+
Connection objects can be used as context managers to ensure that transactions are properly committed or rolled back. When entering the context, the connection object is returned. When exiting:
38+
- Without exception: automatically commits the transaction
39+
- With exception: automatically rolls back the transaction
40+
41+
This behavior is compatible with Python's `sqlite3` module. Context managers work correctly in both transactional and autocommit modes.
42+
43+
When mixing manual transaction control with context managers, the context manager's commit/rollback will apply to any active transaction at the time of exit. Manual calls to `commit()` or `rollback()` within the context are allowed and will start a new transaction as usual.
44+
3545
### execute(sql, parameters=())
3646

3747
Create a new cursor object and executes the SQL statement.

src/lib.rs

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use pyo3::create_exception;
33
use pyo3::exceptions::PyValueError;
44
use pyo3::prelude::*;
55
use pyo3::types::{PyList, PyTuple};
6-
use std::cell::{OnceCell, RefCell};
6+
use std::cell::RefCell;
77
use std::sync::{Arc, OnceLock};
88
use std::time::Duration;
99
use tokio::runtime::{Handle, Runtime};
@@ -38,14 +38,14 @@ fn is_remote_path(path: &str) -> bool {
3838

3939
#[pyfunction]
4040
#[cfg(not(Py_3_12))]
41-
#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), check_same_thread=true, uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None))]
41+
#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), _check_same_thread=true, _uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None))]
4242
fn connect(
4343
py: Python<'_>,
4444
database: String,
4545
timeout: f64,
4646
isolation_level: Option<String>,
47-
check_same_thread: bool,
48-
uri: bool,
47+
_check_same_thread: bool,
48+
_uri: bool,
4949
sync_url: Option<String>,
5050
sync_interval: Option<f64>,
5151
auth_token: &str,
@@ -56,8 +56,8 @@ fn connect(
5656
database,
5757
timeout,
5858
isolation_level,
59-
check_same_thread,
60-
uri,
59+
_check_same_thread,
60+
_uri,
6161
sync_url,
6262
sync_interval,
6363
auth_token,
@@ -68,14 +68,14 @@ fn connect(
6868

6969
#[pyfunction]
7070
#[cfg(Py_3_12)]
71-
#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), check_same_thread=true, uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None, autocommit = LEGACY_TRANSACTION_CONTROL))]
71+
#[pyo3(signature = (database, timeout=5.0, isolation_level="DEFERRED".to_string(), _check_same_thread=true, _uri=false, sync_url=None, sync_interval=None, auth_token="", encryption_key=None, autocommit = LEGACY_TRANSACTION_CONTROL))]
7272
fn connect(
7373
py: Python<'_>,
7474
database: String,
7575
timeout: f64,
7676
isolation_level: Option<String>,
77-
check_same_thread: bool,
78-
uri: bool,
77+
_check_same_thread: bool,
78+
_uri: bool,
7979
sync_url: Option<String>,
8080
sync_interval: Option<f64>,
8181
auth_token: &str,
@@ -87,8 +87,8 @@ fn connect(
8787
database,
8888
timeout,
8989
isolation_level.clone(),
90-
check_same_thread,
91-
uri,
90+
_check_same_thread,
91+
_uri,
9292
sync_url,
9393
sync_interval,
9494
auth_token,
@@ -111,8 +111,8 @@ fn _connect_core(
111111
database: String,
112112
timeout: f64,
113113
isolation_level: Option<String>,
114-
check_same_thread: bool,
115-
uri: bool,
114+
_check_same_thread: bool,
115+
_uri: bool,
116116
sync_url: Option<String>,
117117
sync_interval: Option<f64>,
118118
auth_token: &str,
@@ -220,7 +220,7 @@ unsafe impl Send for Connection {}
220220

221221
#[pymethods]
222222
impl Connection {
223-
fn close(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> {
223+
fn close(self_: PyRef<'_, Self>, _py: Python<'_>) -> PyResult<()> {
224224
self_.conn.replace(None);
225225
Ok(())
226226
}
@@ -330,11 +330,14 @@ impl Connection {
330330
fn in_transaction(self_: PyRef<'_, Self>) -> PyResult<bool> {
331331
#[cfg(Py_3_12)]
332332
{
333-
return Ok(
333+
Ok(
334334
!self_.conn.borrow().as_ref().unwrap().is_autocommit() || self_.autocommit == 0
335-
);
335+
)
336+
}
337+
#[cfg(not(Py_3_12))]
338+
{
339+
Ok(!self_.conn.borrow().as_ref().unwrap().is_autocommit())
336340
}
337-
Ok(!self_.conn.borrow().as_ref().unwrap().is_autocommit())
338341
}
339342

340343
#[getter]
@@ -354,6 +357,26 @@ impl Connection {
354357
self_.autocommit = autocommit;
355358
Ok(())
356359
}
360+
361+
fn __enter__(slf: PyRef<'_, Self>) -> PyResult<PyRef<'_, Self>> {
362+
Ok(slf)
363+
}
364+
365+
fn __exit__(
366+
self_: PyRef<'_, Self>,
367+
exc_type: Option<&PyAny>,
368+
_exc_val: Option<&PyAny>,
369+
_exc_tb: Option<&PyAny>,
370+
) -> PyResult<bool> {
371+
if exc_type.is_none() {
372+
// Commit on clean exit
373+
Connection::commit(self_)?;
374+
} else {
375+
// Rollback on error
376+
Connection::rollback(self_)?;
377+
}
378+
Ok(false) // Always propagate exceptions
379+
}
357380
}
358381

359382
#[pyclass]

0 commit comments

Comments
 (0)