Skip to content

Commit b8a0d49

Browse files
committed
feat: Allow parameters to be tuple or lists
This is mostly to improve the dev experience to avoid forcing a specific type allowing some duck typing
1 parent 5188c17 commit b8a0d49

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

src/lib.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@ use pyo3::create_exception;
33
use pyo3::exceptions::PyValueError;
44
use pyo3::prelude::*;
55
use pyo3::types::{PyList, PyTuple};
6-
use std::cell::{OnceCell, RefCell};
6+
use std::cell::RefCell;
77
use std::sync::{Arc, OnceLock};
88
use std::time::Duration;
99
use tokio::runtime::{Handle, Runtime};
1010

1111
const LEGACY_TRANSACTION_CONTROL: i32 = -1;
1212

13+
enum ListOrTuple<'py> {
14+
List(&'py PyList),
15+
Tuple(&'py PyTuple),
16+
}
17+
18+
struct ListOrTupleIterator<'py> {
19+
index: usize,
20+
inner: &'py ListOrTuple<'py>
21+
}
22+
1323
fn rt() -> Handle {
1424
static RT: OnceLock<Runtime> = OnceLock::new();
1525

@@ -286,7 +296,7 @@ impl Connection {
286296
fn execute(
287297
self_: PyRef<'_, Self>,
288298
sql: String,
289-
parameters: Option<&PyTuple>,
299+
parameters: Option<ListOrTuple>,
290300
) -> PyResult<Cursor> {
291301
let cursor = Connection::cursor(&self_)?;
292302
rt().block_on(async { execute(&cursor, sql, parameters).await })?;
@@ -300,7 +310,7 @@ impl Connection {
300310
) -> PyResult<Cursor> {
301311
let cursor = Connection::cursor(&self_)?;
302312
for parameters in parameters.unwrap().iter() {
303-
let parameters = parameters.extract::<&PyTuple>()?;
313+
let parameters = parameters.extract::<ListOrTuple>()?;
304314
rt().block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
305315
}
306316
Ok(cursor)
@@ -396,7 +406,7 @@ impl Cursor {
396406
fn execute<'a>(
397407
self_: PyRef<'a, Self>,
398408
sql: String,
399-
parameters: Option<&PyTuple>,
409+
parameters: Option<ListOrTuple>,
400410
) -> PyResult<pyo3::PyRef<'a, Self>> {
401411
rt().block_on(async { execute(&self_, sql, parameters).await })?;
402412
Ok(self_)
@@ -408,7 +418,7 @@ impl Cursor {
408418
parameters: Option<&PyList>,
409419
) -> PyResult<pyo3::PyRef<'a, Cursor>> {
410420
for parameters in parameters.unwrap().iter() {
411-
let parameters = parameters.extract::<&PyTuple>()?;
421+
let parameters = parameters.extract::<ListOrTuple>()?;
412422
rt().block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
413423
}
414424
Ok(self_)
@@ -552,7 +562,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
552562
Ok(())
553563
}
554564

555-
async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) -> PyResult<()> {
565+
async fn execute<'py>(
566+
cursor: &Cursor,
567+
sql: String,
568+
parameters: Option<ListOrTuple<'py>>,
569+
) -> PyResult<()> {
556570
if cursor.conn.borrow().as_ref().is_none() {
557571
return Err(PyValueError::new_err("Connection already closed"));
558572
}
@@ -576,7 +590,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
576590
} else if let Ok(value) = param.extract::<&[u8]>() {
577591
libsql_core::Value::Blob(value.to_vec())
578592
} else {
579-
return Err(PyValueError::new_err("Unsupported parameter type"));
593+
return Err(PyValueError::new_err(format!(
594+
"Unsupported parameter type {}",
595+
param.to_string()
596+
)));
580597
};
581598
params.push(param);
582599
}
@@ -653,6 +670,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
653670

654671
create_exception!(libsql, Error, pyo3::exceptions::PyException);
655672

673+
impl<'py> FromPyObject<'py> for ListOrTuple<'py> {
674+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
675+
if let Ok(list) = ob.downcast::<PyList>() {
676+
Ok(ListOrTuple::List(list))
677+
} else if let Ok(tuple) = ob.downcast::<PyTuple>() {
678+
Ok(ListOrTuple::Tuple(tuple))
679+
} else {
680+
Err(PyValueError::new_err(
681+
"Expected a list or tuple for parameters",
682+
))
683+
}
684+
}
685+
}
686+
687+
impl<'py> ListOrTuple<'py> {
688+
pub fn iter(&self) -> ListOrTupleIterator {
689+
ListOrTupleIterator{
690+
index: 0,
691+
inner: self,
692+
}
693+
}
694+
}
695+
696+
impl<'py> Iterator for ListOrTupleIterator<'py> {
697+
type Item = &'py PyAny;
698+
699+
fn next(&mut self) -> Option<Self::Item> {
700+
let rv = match self.inner {
701+
ListOrTuple::List(list) => list.get_item(self.index),
702+
ListOrTuple::Tuple(tuple) => tuple.get_item(self.index),
703+
};
704+
705+
rv.ok().map(|item| {
706+
self.index += 1;
707+
item
708+
})
709+
}
710+
}
656711
#[pymodule]
657712
fn libsql(py: Python, m: &PyModule) -> PyResult<()> {
658713
let _ = tracing_subscriber::fmt::try_init();

tests/test_suite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def test_execute(provider):
2323
conn.execute("INSERT INTO users VALUES (1, 'alice@example.com')")
2424
res = conn.execute("SELECT * FROM users")
2525
assert (1, "alice@example.com") == res.fetchone()
26+
# allow lists for parameters as well
27+
res = conn.execute("SELECT * FROM users WHERE id = ?", [1])
28+
assert (1, "alice@example.com") == res.fetchone()
2629

2730

2831
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])

0 commit comments

Comments
 (0)