diff --git a/src/lib.rs b/src/lib.rs index ed7d444..4b358d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,13 +3,23 @@ use pyo3::create_exception; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyList, PyTuple}; -use std::cell::{OnceCell, RefCell}; +use std::cell::RefCell; use std::sync::{Arc, OnceLock}; use std::time::Duration; use tokio::runtime::{Handle, Runtime}; const LEGACY_TRANSACTION_CONTROL: i32 = -1; +enum ListOrTuple<'py> { + List(&'py PyList), + Tuple(&'py PyTuple), +} + +struct ListOrTupleIterator<'py> { + index: usize, + inner: &'py ListOrTuple<'py> +} + fn rt() -> Handle { static RT: OnceLock = OnceLock::new(); @@ -286,7 +296,7 @@ impl Connection { fn execute( self_: PyRef<'_, Self>, sql: String, - parameters: Option<&PyTuple>, + parameters: Option, ) -> PyResult { let cursor = Connection::cursor(&self_)?; rt().block_on(async { execute(&cursor, sql, parameters).await })?; @@ -300,7 +310,7 @@ impl Connection { ) -> PyResult { let cursor = Connection::cursor(&self_)?; for parameters in parameters.unwrap().iter() { - let parameters = parameters.extract::<&PyTuple>()?; + let parameters = parameters.extract::()?; rt().block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?; } Ok(cursor) @@ -396,7 +406,7 @@ impl Cursor { fn execute<'a>( self_: PyRef<'a, Self>, sql: String, - parameters: Option<&PyTuple>, + parameters: Option, ) -> PyResult> { rt().block_on(async { execute(&self_, sql, parameters).await })?; Ok(self_) @@ -408,7 +418,7 @@ impl Cursor { parameters: Option<&PyList>, ) -> PyResult> { for parameters in parameters.unwrap().iter() { - let parameters = parameters.extract::<&PyTuple>()?; + let parameters = parameters.extract::()?; rt().block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?; } Ok(self_) @@ -552,7 +562,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> { Ok(()) } -async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) -> PyResult<()> { +async fn execute<'py>( + cursor: &Cursor, + sql: String, + parameters: Option>, +) -> PyResult<()> { if cursor.conn.borrow().as_ref().is_none() { return Err(PyValueError::new_err("Connection already closed")); } @@ -576,7 +590,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) -> } else if let Ok(value) = param.extract::<&[u8]>() { libsql_core::Value::Blob(value.to_vec()) } else { - return Err(PyValueError::new_err("Unsupported parameter type")); + return Err(PyValueError::new_err(format!( + "Unsupported parameter type {}", + param.to_string() + ))); }; params.push(param); } @@ -653,6 +670,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult create_exception!(libsql, Error, pyo3::exceptions::PyException); +impl<'py> FromPyObject<'py> for ListOrTuple<'py> { + fn extract(ob: &'py PyAny) -> PyResult { + if let Ok(list) = ob.downcast::() { + Ok(ListOrTuple::List(list)) + } else if let Ok(tuple) = ob.downcast::() { + Ok(ListOrTuple::Tuple(tuple)) + } else { + Err(PyValueError::new_err( + "Expected a list or tuple for parameters", + )) + } + } +} + +impl<'py> ListOrTuple<'py> { + pub fn iter(&self) -> ListOrTupleIterator { + ListOrTupleIterator{ + index: 0, + inner: self, + } + } +} + +impl<'py> Iterator for ListOrTupleIterator<'py> { + type Item = &'py PyAny; + + fn next(&mut self) -> Option { + let rv = match self.inner { + ListOrTuple::List(list) => list.get_item(self.index), + ListOrTuple::Tuple(tuple) => tuple.get_item(self.index), + }; + + rv.ok().map(|item| { + self.index += 1; + item + }) + } +} #[pymodule] fn libsql(py: Python, m: &PyModule) -> PyResult<()> { let _ = tracing_subscriber::fmt::try_init(); diff --git a/tests/test_suite.py b/tests/test_suite.py index 428f314..dc3b499 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -23,6 +23,9 @@ def test_execute(provider): conn.execute("INSERT INTO users VALUES (1, 'alice@example.com')") res = conn.execute("SELECT * FROM users") assert (1, "alice@example.com") == res.fetchone() + # allow lists for parameters as well + res = conn.execute("SELECT * FROM users WHERE id = ?", [1]) + assert (1, "alice@example.com") == res.fetchone() @pytest.mark.parametrize("provider", ["libsql", "sqlite"])