@@ -10,6 +10,16 @@ use tokio::runtime::{Handle, Runtime};
10
10
11
11
const LEGACY_TRANSACTION_CONTROL : i32 = -1 ;
12
12
13
+ enum ListOrTuple < ' py > {
14
+ List ( & ' py PyList ) ,
15
+ Tuple ( & ' py PyTuple ) ,
16
+ }
17
+
18
+ struct ListOrTupleIterator < ' py > {
19
+ index : usize ,
20
+ inner : & ' py ListOrTuple < ' py >
21
+ }
22
+
13
23
fn rt ( ) -> Handle {
14
24
static RT : OnceLock < Runtime > = OnceLock :: new ( ) ;
15
25
@@ -286,7 +296,7 @@ impl Connection {
286
296
fn execute (
287
297
self_ : PyRef < ' _ , Self > ,
288
298
sql : String ,
289
- parameters : Option < & PyTuple > ,
299
+ parameters : Option < ListOrTuple > ,
290
300
) -> PyResult < Cursor > {
291
301
let cursor = Connection :: cursor ( & self_) ?;
292
302
rt ( ) . block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
@@ -300,7 +310,7 @@ impl Connection {
300
310
) -> PyResult < Cursor > {
301
311
let cursor = Connection :: cursor ( & self_) ?;
302
312
for parameters in parameters. unwrap ( ) . iter ( ) {
303
- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
313
+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
304
314
rt ( ) . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
305
315
}
306
316
Ok ( cursor)
@@ -419,7 +429,7 @@ impl Cursor {
419
429
fn execute < ' a > (
420
430
self_ : PyRef < ' a , Self > ,
421
431
sql : String ,
422
- parameters : Option < & PyTuple > ,
432
+ parameters : Option < ListOrTuple > ,
423
433
) -> PyResult < pyo3:: PyRef < ' a , Self > > {
424
434
rt ( ) . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
425
435
Ok ( self_)
@@ -431,7 +441,7 @@ impl Cursor {
431
441
parameters : Option < & PyList > ,
432
442
) -> PyResult < pyo3:: PyRef < ' a , Cursor > > {
433
443
for parameters in parameters. unwrap ( ) . iter ( ) {
434
- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
444
+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
435
445
rt ( ) . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
436
446
}
437
447
Ok ( self_)
@@ -575,7 +585,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
575
585
Ok ( ( ) )
576
586
}
577
587
578
- async fn execute ( cursor : & Cursor , sql : String , parameters : Option < & PyTuple > ) -> PyResult < ( ) > {
588
+ async fn execute < ' py > (
589
+ cursor : & Cursor ,
590
+ sql : String ,
591
+ parameters : Option < ListOrTuple < ' py > > ,
592
+ ) -> PyResult < ( ) > {
579
593
if cursor. conn . borrow ( ) . as_ref ( ) . is_none ( ) {
580
594
return Err ( PyValueError :: new_err ( "Connection already closed" ) ) ;
581
595
}
@@ -599,7 +613,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
599
613
} else if let Ok ( value) = param. extract :: < & [ u8 ] > ( ) {
600
614
libsql_core:: Value :: Blob ( value. to_vec ( ) )
601
615
} else {
602
- return Err ( PyValueError :: new_err ( "Unsupported parameter type" ) ) ;
616
+ return Err ( PyValueError :: new_err ( format ! (
617
+ "Unsupported parameter type {}" ,
618
+ param. to_string( )
619
+ ) ) ) ;
603
620
} ;
604
621
params. push ( param) ;
605
622
}
@@ -676,6 +693,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
676
693
677
694
create_exception ! ( libsql, Error , pyo3:: exceptions:: PyException ) ;
678
695
696
+ impl < ' py > FromPyObject < ' py > for ListOrTuple < ' py > {
697
+ fn extract ( ob : & ' py PyAny ) -> PyResult < Self > {
698
+ if let Ok ( list) = ob. downcast :: < PyList > ( ) {
699
+ Ok ( ListOrTuple :: List ( list) )
700
+ } else if let Ok ( tuple) = ob. downcast :: < PyTuple > ( ) {
701
+ Ok ( ListOrTuple :: Tuple ( tuple) )
702
+ } else {
703
+ Err ( PyValueError :: new_err (
704
+ "Expected a list or tuple for parameters" ,
705
+ ) )
706
+ }
707
+ }
708
+ }
709
+
710
+ impl < ' py > ListOrTuple < ' py > {
711
+ pub fn iter ( & self ) -> ListOrTupleIterator {
712
+ ListOrTupleIterator {
713
+ index : 0 ,
714
+ inner : self ,
715
+ }
716
+ }
717
+ }
718
+
719
+ impl < ' py > Iterator for ListOrTupleIterator < ' py > {
720
+ type Item = & ' py PyAny ;
721
+
722
+ fn next ( & mut self ) -> Option < Self :: Item > {
723
+ let rv = match self . inner {
724
+ ListOrTuple :: List ( list) => list. get_item ( self . index ) ,
725
+ ListOrTuple :: Tuple ( tuple) => tuple. get_item ( self . index ) ,
726
+ } ;
727
+
728
+ rv. ok ( ) . map ( |item| {
729
+ self . index += 1 ;
730
+ item
731
+ } )
732
+ }
733
+ }
679
734
#[ pymodule]
680
735
fn libsql ( py : Python , m : & PyModule ) -> PyResult < ( ) > {
681
736
let _ = tracing_subscriber:: fmt:: try_init ( ) ;
0 commit comments