Skip to content

Commit 9b62ba2

Browse files
authored
feat: Allow parameters to be tuple or lists (#97)
This is mostly to improve the dev experience to avoid forcing a specific type allowing some duck typing
2 parents 696a23d + b8a0d49 commit 9b62ba2

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

src/lib.rs

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ use tokio::runtime::{Handle, Runtime};
1010

1111
const LEGACY_TRANSACTION_CONTROL: i32 = -1;
1212

13+
enum ListOrTuple<'py> {
14+
List(&'py PyList),
15+
Tuple(&'py PyTuple),
16+
}
17+
18+
struct ListOrTupleIterator<'py> {
19+
index: usize,
20+
inner: &'py ListOrTuple<'py>
21+
}
22+
1323
fn rt() -> Handle {
1424
static RT: OnceLock<Runtime> = OnceLock::new();
1525

@@ -286,7 +296,7 @@ impl Connection {
286296
fn execute(
287297
self_: PyRef<'_, Self>,
288298
sql: String,
289-
parameters: Option<&PyTuple>,
299+
parameters: Option<ListOrTuple>,
290300
) -> PyResult<Cursor> {
291301
let cursor = Connection::cursor(&self_)?;
292302
rt().block_on(async { execute(&cursor, sql, parameters).await })?;
@@ -300,7 +310,7 @@ impl Connection {
300310
) -> PyResult<Cursor> {
301311
let cursor = Connection::cursor(&self_)?;
302312
for parameters in parameters.unwrap().iter() {
303-
let parameters = parameters.extract::<&PyTuple>()?;
313+
let parameters = parameters.extract::<ListOrTuple>()?;
304314
rt().block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
305315
}
306316
Ok(cursor)
@@ -419,7 +429,7 @@ impl Cursor {
419429
fn execute<'a>(
420430
self_: PyRef<'a, Self>,
421431
sql: String,
422-
parameters: Option<&PyTuple>,
432+
parameters: Option<ListOrTuple>,
423433
) -> PyResult<pyo3::PyRef<'a, Self>> {
424434
rt().block_on(async { execute(&self_, sql, parameters).await })?;
425435
Ok(self_)
@@ -431,7 +441,7 @@ impl Cursor {
431441
parameters: Option<&PyList>,
432442
) -> PyResult<pyo3::PyRef<'a, Cursor>> {
433443
for parameters in parameters.unwrap().iter() {
434-
let parameters = parameters.extract::<&PyTuple>()?;
444+
let parameters = parameters.extract::<ListOrTuple>()?;
435445
rt().block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
436446
}
437447
Ok(self_)
@@ -575,7 +585,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
575585
Ok(())
576586
}
577587

578-
async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) -> PyResult<()> {
588+
async fn execute<'py>(
589+
cursor: &Cursor,
590+
sql: String,
591+
parameters: Option<ListOrTuple<'py>>,
592+
) -> PyResult<()> {
579593
if cursor.conn.borrow().as_ref().is_none() {
580594
return Err(PyValueError::new_err("Connection already closed"));
581595
}
@@ -599,7 +613,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
599613
} else if let Ok(value) = param.extract::<&[u8]>() {
600614
libsql_core::Value::Blob(value.to_vec())
601615
} else {
602-
return Err(PyValueError::new_err("Unsupported parameter type"));
616+
return Err(PyValueError::new_err(format!(
617+
"Unsupported parameter type {}",
618+
param.to_string()
619+
)));
603620
};
604621
params.push(param);
605622
}
@@ -676,6 +693,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
676693

677694
create_exception!(libsql, Error, pyo3::exceptions::PyException);
678695

696+
impl<'py> FromPyObject<'py> for ListOrTuple<'py> {
697+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
698+
if let Ok(list) = ob.downcast::<PyList>() {
699+
Ok(ListOrTuple::List(list))
700+
} else if let Ok(tuple) = ob.downcast::<PyTuple>() {
701+
Ok(ListOrTuple::Tuple(tuple))
702+
} else {
703+
Err(PyValueError::new_err(
704+
"Expected a list or tuple for parameters",
705+
))
706+
}
707+
}
708+
}
709+
710+
impl<'py> ListOrTuple<'py> {
711+
pub fn iter(&self) -> ListOrTupleIterator {
712+
ListOrTupleIterator{
713+
index: 0,
714+
inner: self,
715+
}
716+
}
717+
}
718+
719+
impl<'py> Iterator for ListOrTupleIterator<'py> {
720+
type Item = &'py PyAny;
721+
722+
fn next(&mut self) -> Option<Self::Item> {
723+
let rv = match self.inner {
724+
ListOrTuple::List(list) => list.get_item(self.index),
725+
ListOrTuple::Tuple(tuple) => tuple.get_item(self.index),
726+
};
727+
728+
rv.ok().map(|item| {
729+
self.index += 1;
730+
item
731+
})
732+
}
733+
}
679734
#[pymodule]
680735
fn libsql(py: Python, m: &PyModule) -> PyResult<()> {
681736
let _ = tracing_subscriber::fmt::try_init();

tests/test_suite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def test_execute(provider):
2626
conn.execute("INSERT INTO users VALUES (1, 'alice@example.com')")
2727
res = conn.execute("SELECT * FROM users")
2828
assert (1, "alice@example.com") == res.fetchone()
29+
# allow lists for parameters as well
30+
res = conn.execute("SELECT * FROM users WHERE id = ?", [1])
31+
assert (1, "alice@example.com") == res.fetchone()
2932

3033

3134
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])

0 commit comments

Comments
 (0)