@@ -4,8 +4,10 @@ use std::fmt::{self, Debug, Formatter};
4
4
use csv as rust_csv;
5
5
use itertools:: join;
6
6
7
+ use crate :: function:: PyFuncArgs ;
8
+
7
9
use crate :: obj:: objiter;
8
- use crate :: obj:: objstr:: PyString ;
10
+ use crate :: obj:: objstr:: { self , PyString } ;
9
11
use crate :: obj:: objtype:: PyClassRef ;
10
12
use crate :: pyobject:: { IntoPyObject , TryFromObject , TypeProtocol } ;
11
13
use crate :: pyobject:: { PyClassImpl , PyIterable , PyObjectRef , PyRef , PyResult , PyValue } ;
@@ -20,8 +22,58 @@ pub enum QuoteStyle {
20
22
QuoteNone ,
21
23
}
22
24
23
- pub fn build_reader ( iterable : PyIterable < PyObjectRef > , vm : & VirtualMachine ) -> PyResult {
24
- Reader :: new ( iterable) . into_ref ( vm) . into_pyobject ( vm)
25
/// Parsing options for a CSV reader, collected from the keyword arguments
/// passed to `_csv.reader`.
///
/// Both values are stored as single bytes because they are handed to the
/// `csv` crate's `ReaderBuilder`, whose `delimiter`/`quote` setters take `u8`.
struct ReaderOption {
    /// Byte that separates fields within a record.
    delimiter: u8,
    /// Byte used to quote fields.
    quotechar: u8,
}
29
+
30
+ impl ReaderOption {
31
+ fn new ( args : PyFuncArgs , vm : & VirtualMachine ) -> PyResult < Self > {
32
+ let delimiter = {
33
+ let bytes = args
34
+ . get_optional_kwarg ( "delimiter" )
35
+ . map_or ( "," . to_string ( ) , |pyobj| objstr:: get_value ( & pyobj) )
36
+ . into_bytes ( ) ;
37
+
38
+ match bytes. len ( ) {
39
+ 1 => bytes[ 0 ] ,
40
+ _ => {
41
+ let msg = r#""delimiter" must be a 1-character string"# ;
42
+ return Err ( vm. new_type_error ( msg. to_string ( ) ) ) ;
43
+ }
44
+ }
45
+ } ;
46
+
47
+ let quotechar = {
48
+ let bytes = args
49
+ . get_optional_kwarg ( "quotechar" )
50
+ . map_or ( "\" " . to_string ( ) , |pyobj| objstr:: get_value ( & pyobj) )
51
+ . into_bytes ( ) ;
52
+
53
+ match bytes. len ( ) {
54
+ 1 => bytes[ 0 ] ,
55
+ _ => {
56
+ let msg = r#""quotechar" must be a 1-character string"# ;
57
+ return Err ( vm. new_type_error ( msg. to_string ( ) ) ) ;
58
+ }
59
+ }
60
+ } ;
61
+
62
+ Ok ( ReaderOption {
63
+ delimiter,
64
+ quotechar,
65
+ } )
66
+ }
67
+ }
68
+
69
+ pub fn build_reader (
70
+ iterable : PyIterable < PyObjectRef > ,
71
+ args : PyFuncArgs ,
72
+ vm : & VirtualMachine ,
73
+ ) -> PyResult {
74
+ let config = ReaderOption :: new ( args, vm) ?;
75
+
76
+ Reader :: new ( iterable, config) . into_ref ( vm) . into_pyobject ( vm)
25
77
}
26
78
27
79
fn into_strings ( iterable : & PyIterable < PyObjectRef > , vm : & VirtualMachine ) -> PyResult < Vec < String > > {
@@ -46,24 +98,26 @@ type MemIO = std::io::Cursor<Vec<u8>>;
46
98
47
99
#[ allow( dead_code) ]
48
100
enum ReadState {
49
- PyIter ( PyIterable < PyObjectRef > ) ,
101
+ PyIter ( PyIterable < PyObjectRef > , ReaderOption ) ,
50
102
CsvIter ( rust_csv:: StringRecordsIntoIter < MemIO > ) ,
51
103
}
52
104
53
105
impl ReadState {
54
- fn new ( iter : PyIterable ) -> Self {
55
- ReadState :: PyIter ( iter)
106
+ fn new ( iter : PyIterable , config : ReaderOption ) -> Self {
107
+ ReadState :: PyIter ( iter, config )
56
108
}
57
109
58
110
fn cast_to_reader ( & mut self , vm : & VirtualMachine ) -> PyResult < ( ) > {
59
- if let ReadState :: PyIter ( ref iterable) = self {
111
+ if let ReadState :: PyIter ( ref iterable, ref config ) = self {
60
112
let lines = into_strings ( iterable, vm) ?;
61
113
let contents = join ( lines, "\n " ) ;
62
114
63
115
let bytes = Vec :: from ( contents. as_bytes ( ) ) ;
64
116
let reader = MemIO :: new ( bytes) ;
65
117
66
118
let csv_iter = rust_csv:: ReaderBuilder :: new ( )
119
+ . delimiter ( config. delimiter )
120
+ . quote ( config. quotechar )
67
121
. has_headers ( false )
68
122
. from_reader ( reader)
69
123
. into_records ( ) ;
@@ -92,8 +146,8 @@ impl PyValue for Reader {
92
146
}
93
147
94
148
impl Reader {
95
- fn new ( iter : PyIterable < PyObjectRef > ) -> Self {
96
- let state = RefCell :: new ( ReadState :: new ( iter) ) ;
149
+ fn new ( iter : PyIterable < PyObjectRef > , config : ReaderOption ) -> Self {
150
+ let state = RefCell :: new ( ReadState :: new ( iter, config ) ) ;
97
151
Reader { state }
98
152
}
99
153
}
@@ -121,7 +175,7 @@ impl Reader {
121
175
. collect :: < PyResult < Vec < _ > > > ( ) ?;
122
176
Ok ( vm. ctx . new_list ( iter) )
123
177
}
124
- Err ( _ ) => {
178
+ Err ( _err ) => {
125
179
let msg = String :: from ( "Decode Error" ) ;
126
180
let decode_error = vm. new_unicode_decode_error ( msg) ;
127
181
Err ( decode_error)
@@ -136,9 +190,9 @@ impl Reader {
136
190
}
137
191
}
138
192
139
- fn csv_reader ( fp : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
193
+ fn csv_reader ( fp : PyObjectRef , args : PyFuncArgs , vm : & VirtualMachine ) -> PyResult {
140
194
if let Ok ( iterable) = PyIterable :: < PyObjectRef > :: try_from_object ( vm, fp) {
141
- build_reader ( iterable, vm)
195
+ build_reader ( iterable, args , vm)
142
196
} else {
143
197
Err ( vm. new_type_error ( "argument 1 must be an iterator" . to_string ( ) ) )
144
198
}
@@ -156,13 +210,13 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
156
210
) ;
157
211
158
212
py_module ! ( vm, "_csv" , {
159
- "reader" => ctx. new_rustfunc( csv_reader) ,
160
- "Reader" => reader_type,
161
- "Error" => error,
162
- // constants
163
- "QUOTE_MINIMAL" => ctx. new_int( QuoteStyle :: QuoteMinimal as i32 ) ,
164
- "QUOTE_ALL" => ctx. new_int( QuoteStyle :: QuoteAll as i32 ) ,
165
- "QUOTE_NONNUMERIC" => ctx. new_int( QuoteStyle :: QuoteNonnumeric as i32 ) ,
166
- "QUOTE_NONE" => ctx. new_int( QuoteStyle :: QuoteNone as i32 ) ,
213
+ "reader" => ctx. new_rustfunc( csv_reader) ,
214
+ "Reader" => reader_type,
215
+ "Error" => error,
216
+ // constants
217
+ "QUOTE_MINIMAL" => ctx. new_int( QuoteStyle :: QuoteMinimal as i32 ) ,
218
+ "QUOTE_ALL" => ctx. new_int( QuoteStyle :: QuoteAll as i32 ) ,
219
+ "QUOTE_NONNUMERIC" => ctx. new_int( QuoteStyle :: QuoteNonnumeric as i32 ) ,
220
+ "QUOTE_NONE" => ctx. new_int( QuoteStyle :: QuoteNone as i32 ) ,
167
221
} )
168
222
}
0 commit comments