1
- use bytes:: BytesMut ;
1
+ use bytes:: { Buf , BytesMut } ;
2
2
use deadpool_postgres:: { Object , Pool } ;
3
3
use futures_util:: pin_mut;
4
+ use postgres_types:: ToSql ;
4
5
use pyo3:: { buffer:: PyBuffer , pyclass, pymethods, Py , PyAny , PyErr , Python } ;
5
6
use std:: { collections:: HashSet , sync:: Arc , vec} ;
6
- use tokio_postgres:: binary_copy:: BinaryCopyInWriter ;
7
+ use tokio_postgres:: { binary_copy:: BinaryCopyInWriter , Client , CopyInSink , Row , Statement , ToStatement } ;
7
8
8
9
use crate :: {
9
10
exceptions:: rust_errors:: { RustPSQLDriverError , RustPSQLDriverPyResult } ,
@@ -19,110 +20,115 @@ use super::{
19
20
transaction_options:: { IsolationLevel , ReadVariant , SynchronousCommit } ,
20
21
} ;
21
22
22
- /// Format OPTS parameter for Postgres COPY command.
23
- ///
24
- /// # Errors
25
- /// May return Err Result if cannot format parameter.
26
- #[ allow( clippy:: too_many_arguments) ]
27
- pub fn _format_copy_opts (
28
- format : Option < String > ,
29
- freeze : Option < bool > ,
30
- delimiter : Option < String > ,
31
- null : Option < String > ,
32
- header : Option < String > ,
33
- quote : Option < String > ,
34
- escape : Option < String > ,
35
- force_quote : Option < Py < PyAny > > ,
36
- force_not_null : Option < Vec < String > > ,
37
- force_null : Option < Vec < String > > ,
38
- encoding : Option < String > ,
39
- ) -> RustPSQLDriverPyResult < String > {
40
- let mut opts: Vec < String > = vec ! [ ] ;
41
-
42
- if let Some ( format) = format {
43
- opts. push ( format ! ( "FORMAT {format}" ) ) ;
44
- }
23
+ pub enum InnerConnection {
24
+ PoolConn ( Object ) ,
25
+ SingleConn ( Client ) ,
26
+ }
45
27
46
- if let Some ( freeze) = freeze {
47
- if freeze {
48
- opts. push ( "FREEZE TRUE" . into ( ) ) ;
49
- } else {
50
- opts. push ( "FREEZE FALSE" . into ( ) ) ;
28
+ impl InnerConnection {
29
+ pub async fn prepare_cached (
30
+ & self ,
31
+ query : & str
32
+ ) -> RustPSQLDriverPyResult < Statement > {
33
+ match self {
34
+ InnerConnection :: PoolConn ( pconn) => {
35
+ return Ok ( pconn. prepare_cached ( query) . await ?)
36
+ }
37
+ InnerConnection :: SingleConn ( sconn) => {
38
+ return Ok ( sconn. prepare ( query) . await ?)
39
+ }
51
40
}
52
41
}
53
-
54
- if let Some ( delimiter) = delimiter {
55
- opts. push ( format ! ( "DELIMITER {delimiter}" ) ) ;
56
- }
57
-
58
- if let Some ( null) = null {
59
- opts. push ( format ! ( "NULL {}" , quote_ident( & null) ) ) ;
60
- }
61
-
62
- if let Some ( header) = header {
63
- opts. push ( format ! ( "HEADER {header}" ) ) ;
64
- }
65
-
66
- if let Some ( quote) = quote {
67
- opts. push ( format ! ( "QUOTE {quote}" ) ) ;
68
- }
69
-
70
- if let Some ( escape) = escape {
71
- opts. push ( format ! ( "ESCAPE {escape}" ) ) ;
72
- }
73
-
74
- if let Some ( force_quote) = force_quote {
75
- let boolean_force_quote: Result < bool , PyErr > =
76
- Python :: with_gil ( |gil| force_quote. extract :: < bool > ( gil) ) ;
77
-
78
- if let Ok ( force_quote) = boolean_force_quote {
79
- if force_quote {
80
- opts. push ( "FORCE_QUOTE *" . into ( ) ) ;
42
+
43
+ pub async fn query < T > (
44
+ & self ,
45
+ statement : & T ,
46
+ params : & [ & ( dyn ToSql + Sync ) ] ,
47
+ ) -> RustPSQLDriverPyResult < Vec < Row > >
48
+ where T : ?Sized + ToStatement {
49
+ match self {
50
+ InnerConnection :: PoolConn ( pconn) => {
51
+ return Ok ( pconn. query ( statement, params) . await ?)
81
52
}
82
- } else {
83
- let sequence_force_quote: Result < Vec < String > , PyErr > =
84
- Python :: with_gil ( |gil| force_quote. extract :: < Vec < String > > ( gil) ) ;
85
-
86
- if let Ok ( force_quote) = sequence_force_quote {
87
- opts. push ( format ! ( "FORCE_QUOTE ({})" , force_quote. join( ", " ) ) ) ;
53
+ InnerConnection :: SingleConn ( sconn) => {
54
+ return Ok ( sconn. query ( statement, params) . await ?)
88
55
}
89
-
90
- return Err ( RustPSQLDriverError :: PyToRustValueConversionError (
91
- "force_quote parameter must be boolean or sequence of str's." . into ( ) ,
92
- ) ) ;
93
56
}
94
57
}
95
58
96
- if let Some ( force_not_null) = force_not_null {
97
- opts. push ( format ! ( "FORCE_NOT_NULL ({})" , force_not_null. join( ", " ) ) ) ;
98
- }
99
-
100
- if let Some ( force_null) = force_null {
101
- opts. push ( format ! ( "FORCE_NULL ({})" , force_null. join( ", " ) ) ) ;
59
+ pub async fn batch_execute ( & self , query : & str ) -> RustPSQLDriverPyResult < ( ) > {
60
+ match self {
61
+ InnerConnection :: PoolConn ( pconn) => {
62
+ return Ok ( pconn. batch_execute ( query) . await ?)
63
+ }
64
+ InnerConnection :: SingleConn ( sconn) => {
65
+ return Ok ( sconn. batch_execute ( query) . await ?)
66
+ }
67
+ }
102
68
}
103
69
104
- if let Some ( encoding) = encoding {
105
- opts. push ( format ! ( "ENCODING {}" , quote_ident( & encoding) ) ) ;
70
+ pub async fn query_one < T > (
71
+ & self ,
72
+ statement : & T ,
73
+ params : & [ & ( dyn ToSql + Sync ) ] ,
74
+ ) -> RustPSQLDriverPyResult < Row >
75
+ where T : ?Sized + ToStatement
76
+ {
77
+ match self {
78
+ InnerConnection :: PoolConn ( pconn) => {
79
+ return Ok ( pconn. query_one ( statement, params) . await ?)
80
+ }
81
+ InnerConnection :: SingleConn ( sconn) => {
82
+ return Ok ( sconn. query_one ( statement, params) . await ?)
83
+ }
84
+ }
106
85
}
107
86
108
- if opts. is_empty ( ) {
109
- Ok ( String :: new ( ) )
110
- } else {
111
- Ok ( format ! ( "({})" , opts. join( ", " ) ) )
87
+ pub async fn copy_in < T , U > (
88
+ & self ,
89
+ statement : & T
90
+ ) -> RustPSQLDriverPyResult < CopyInSink < U > >
91
+ where
92
+ T : ?Sized + ToStatement ,
93
+ U : Buf + ' static + Send
94
+ {
95
+ match self {
96
+ InnerConnection :: PoolConn ( pconn) => {
97
+ return Ok ( pconn. copy_in ( statement) . await ?)
98
+ }
99
+ InnerConnection :: SingleConn ( sconn) => {
100
+ return Ok ( sconn. copy_in ( statement) . await ?)
101
+ }
102
+ }
112
103
}
113
104
}
114
105
115
106
#[ pyclass( subclass) ]
107
+ #[ derive( Clone ) ]
116
108
pub struct Connection {
117
- db_client : Option < Arc < Object > > ,
109
+ db_client : Option < Arc < InnerConnection > > ,
118
110
db_pool : Option < Pool > ,
119
111
}
120
112
121
113
impl Connection {
122
114
#[ must_use]
123
- pub fn new ( db_client : Option < Arc < Object > > , db_pool : Option < Pool > ) -> Self {
115
+ pub fn new ( db_client : Option < Arc < InnerConnection > > , db_pool : Option < Pool > ) -> Self {
124
116
Connection { db_client, db_pool }
125
117
}
118
+
119
+ pub fn db_client ( & self ) -> Option < Arc < InnerConnection > > {
120
+ return self . db_client . clone ( )
121
+ }
122
+
123
+ pub fn db_pool ( & self ) -> Option < Pool > {
124
+ return self . db_pool . clone ( )
125
+ }
126
+ }
127
+
128
+ impl Default for Connection {
129
+ fn default ( ) -> Self {
130
+ Connection :: new ( None , None )
131
+ }
126
132
}
127
133
128
134
#[ pymethods]
@@ -145,7 +151,7 @@ impl Connection {
145
151
. await ??;
146
152
pyo3:: Python :: with_gil ( |gil| {
147
153
let mut self_ = self_. borrow_mut ( gil) ;
148
- self_. db_client = Some ( Arc :: new ( db_connection) ) ;
154
+ self_. db_client = Some ( Arc :: new ( InnerConnection :: PoolConn ( db_connection) ) ) ;
149
155
} ) ;
150
156
return Ok ( self_) ;
151
157
}
0 commit comments