@@ -3,13 +3,23 @@ use pyo3::create_exception;
3
3
use pyo3:: exceptions:: PyValueError ;
4
4
use pyo3:: prelude:: * ;
5
5
use pyo3:: types:: { PyList , PyTuple } ;
6
- use std:: cell:: { OnceCell , RefCell } ;
6
+ use std:: cell:: RefCell ;
7
7
use std:: sync:: { Arc , OnceLock } ;
8
8
use std:: time:: Duration ;
9
9
use tokio:: runtime:: { Handle , Runtime } ;
10
10
11
11
const LEGACY_TRANSACTION_CONTROL : i32 = -1 ;
12
12
13
+ enum ListOrTuple < ' py > {
14
+ List ( & ' py PyList ) ,
15
+ Tuple ( & ' py PyTuple ) ,
16
+ }
17
+
18
+ struct ListOrTupleIterator < ' py > {
19
+ index : usize ,
20
+ inner : & ' py ListOrTuple < ' py >
21
+ }
22
+
13
23
fn rt ( ) -> Handle {
14
24
static RT : OnceLock < Runtime > = OnceLock :: new ( ) ;
15
25
@@ -286,7 +296,7 @@ impl Connection {
286
296
fn execute (
287
297
self_ : PyRef < ' _ , Self > ,
288
298
sql : String ,
289
- parameters : Option < & PyTuple > ,
299
+ parameters : Option < ListOrTuple > ,
290
300
) -> PyResult < Cursor > {
291
301
let cursor = Connection :: cursor ( & self_) ?;
292
302
rt ( ) . block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
@@ -300,7 +310,7 @@ impl Connection {
300
310
) -> PyResult < Cursor > {
301
311
let cursor = Connection :: cursor ( & self_) ?;
302
312
for parameters in parameters. unwrap ( ) . iter ( ) {
303
- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
313
+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
304
314
rt ( ) . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
305
315
}
306
316
Ok ( cursor)
@@ -396,7 +406,7 @@ impl Cursor {
396
406
fn execute < ' a > (
397
407
self_ : PyRef < ' a , Self > ,
398
408
sql : String ,
399
- parameters : Option < & PyTuple > ,
409
+ parameters : Option < ListOrTuple > ,
400
410
) -> PyResult < pyo3:: PyRef < ' a , Self > > {
401
411
rt ( ) . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
402
412
Ok ( self_)
@@ -408,7 +418,7 @@ impl Cursor {
408
418
parameters : Option < & PyList > ,
409
419
) -> PyResult < pyo3:: PyRef < ' a , Cursor > > {
410
420
for parameters in parameters. unwrap ( ) . iter ( ) {
411
- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
421
+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
412
422
rt ( ) . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
413
423
}
414
424
Ok ( self_)
@@ -552,7 +562,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
552
562
Ok ( ( ) )
553
563
}
554
564
555
- async fn execute ( cursor : & Cursor , sql : String , parameters : Option < & PyTuple > ) -> PyResult < ( ) > {
565
+ async fn execute < ' py > (
566
+ cursor : & Cursor ,
567
+ sql : String ,
568
+ parameters : Option < ListOrTuple < ' py > > ,
569
+ ) -> PyResult < ( ) > {
556
570
if cursor. conn . borrow ( ) . as_ref ( ) . is_none ( ) {
557
571
return Err ( PyValueError :: new_err ( "Connection already closed" ) ) ;
558
572
}
@@ -576,7 +590,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
576
590
} else if let Ok ( value) = param. extract :: < & [ u8 ] > ( ) {
577
591
libsql_core:: Value :: Blob ( value. to_vec ( ) )
578
592
} else {
579
- return Err ( PyValueError :: new_err ( "Unsupported parameter type" ) ) ;
593
+ return Err ( PyValueError :: new_err ( format ! (
594
+ "Unsupported parameter type {}" ,
595
+ param. to_string( )
596
+ ) ) ) ;
580
597
} ;
581
598
params. push ( param) ;
582
599
}
@@ -653,6 +670,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
653
670
654
671
create_exception ! ( libsql, Error , pyo3:: exceptions:: PyException ) ;
655
672
673
+ impl < ' py > FromPyObject < ' py > for ListOrTuple < ' py > {
674
+ fn extract ( ob : & ' py PyAny ) -> PyResult < Self > {
675
+ if let Ok ( list) = ob. downcast :: < PyList > ( ) {
676
+ Ok ( ListOrTuple :: List ( list) )
677
+ } else if let Ok ( tuple) = ob. downcast :: < PyTuple > ( ) {
678
+ Ok ( ListOrTuple :: Tuple ( tuple) )
679
+ } else {
680
+ Err ( PyValueError :: new_err (
681
+ "Expected a list or tuple for parameters" ,
682
+ ) )
683
+ }
684
+ }
685
+ }
686
+
687
+ impl < ' py > ListOrTuple < ' py > {
688
+ pub fn iter ( & self ) -> ListOrTupleIterator {
689
+ ListOrTupleIterator {
690
+ index : 0 ,
691
+ inner : self ,
692
+ }
693
+ }
694
+ }
695
+
696
+ impl < ' py > Iterator for ListOrTupleIterator < ' py > {
697
+ type Item = & ' py PyAny ;
698
+
699
+ fn next ( & mut self ) -> Option < Self :: Item > {
700
+ let rv = match self . inner {
701
+ ListOrTuple :: List ( list) => list. get_item ( self . index ) ,
702
+ ListOrTuple :: Tuple ( tuple) => tuple. get_item ( self . index ) ,
703
+ } ;
704
+
705
+ rv. ok ( ) . map ( |item| {
706
+ self . index += 1 ;
707
+ item
708
+ } )
709
+ }
710
+ }
656
711
#[ pymodule]
657
712
fn libsql ( py : Python , m : & PyModule ) -> PyResult < ( ) > {
658
713
let _ = tracing_subscriber:: fmt:: try_init ( ) ;
0 commit comments