diff --git a/tests/snippets/test_csv.py b/tests/snippets/test_csv.py index 4d687cfbd7..6ba66d30f7 100644 --- a/tests/snippets/test_csv.py +++ b/tests/snippets/test_csv.py @@ -1,3 +1,5 @@ +from testutils import assert_raises + import csv for row in csv.reader(['one,two,three']): @@ -21,3 +23,23 @@ def f(): assert six == 'six' f() + +def test_delim(): + iter = ['one|two|three', 'four|five|six'] + reader = csv.reader(iter, delimiter='|') + + [one,two,three] = next(reader) + [four,five,six] = next(reader) + + assert one == 'one' + assert two == 'two' + assert three == 'three' + assert four == 'four' + assert five == 'five' + assert six == 'six' + + with assert_raises(TypeError): + iter = ['one,,two,,three'] + csv.reader(iter, delimiter=',,') + +test_delim() diff --git a/vm/src/stdlib/csv.rs b/vm/src/stdlib/csv.rs index c00fb8890d..d937379bb2 100644 --- a/vm/src/stdlib/csv.rs +++ b/vm/src/stdlib/csv.rs @@ -4,8 +4,10 @@ use std::fmt::{self, Debug, Formatter}; use csv as rust_csv; use itertools::join; +use crate::function::PyFuncArgs; + use crate::obj::objiter; -use crate::obj::objstr::PyString; +use crate::obj::objstr::{self, PyString}; use crate::obj::objtype::PyClassRef; use crate::pyobject::{IntoPyObject, TryFromObject, TypeProtocol}; use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; @@ -20,8 +22,58 @@ pub enum QuoteStyle { QuoteNone, } -pub fn build_reader(iterable: PyIterable, vm: &VirtualMachine) -> PyResult { - Reader::new(iterable).into_ref(vm).into_pyobject(vm) +struct ReaderOption { + delimiter: u8, + quotechar: u8, +} + +impl ReaderOption { + fn new(args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + let delimiter = { + let bytes = args + .get_optional_kwarg("delimiter") + .map_or(",".to_string(), |pyobj| objstr::get_value(&pyobj)) + .into_bytes(); + + match bytes.len() { + 1 => bytes[0], + _ => { + let msg = r#""delimiter" must be a 1-character string"#; + return Err(vm.new_type_error(msg.to_string())); + } + } + }; + + let quotechar = { + let bytes = args + .get_optional_kwarg("quotechar") + .map_or("\"".to_string(), |pyobj| objstr::get_value(&pyobj)) + .into_bytes(); + + match bytes.len() { + 1 => bytes[0], + _ => { + let msg = r#""quotechar" must be a 1-character string"#; + return Err(vm.new_type_error(msg.to_string())); + } + } + }; + + Ok(ReaderOption { + delimiter, + quotechar, + }) + } +} + +pub fn build_reader( + iterable: PyIterable, + args: PyFuncArgs, + vm: &VirtualMachine, +) -> PyResult { + let config = ReaderOption::new(args, vm)?; + + Reader::new(iterable, config).into_ref(vm).into_pyobject(vm) } fn into_strings(iterable: &PyIterable, vm: &VirtualMachine) -> PyResult> { @@ -46,17 +98,17 @@ type MemIO = std::io::Cursor>; #[allow(dead_code)] enum ReadState { - PyIter(PyIterable), + PyIter(PyIterable, ReaderOption), CsvIter(rust_csv::StringRecordsIntoIter), } impl ReadState { - fn new(iter: PyIterable) -> Self { - ReadState::PyIter(iter) + fn new(iter: PyIterable, config: ReaderOption) -> Self { + ReadState::PyIter(iter, config) } fn cast_to_reader(&mut self, vm: &VirtualMachine) -> PyResult<()> { - if let ReadState::PyIter(ref iterable) = self { + if let ReadState::PyIter(ref iterable, ref config) = self { let lines = into_strings(iterable, vm)?; let contents = join(lines, "\n"); @@ -64,6 +116,8 @@ impl ReadState { let reader = MemIO::new(bytes); let csv_iter = rust_csv::ReaderBuilder::new() + .delimiter(config.delimiter) + .quote(config.quotechar) .has_headers(false) .from_reader(reader) .into_records(); @@ -92,8 +146,8 @@ impl PyValue for Reader { } impl Reader { - fn new(iter: PyIterable) -> Self { - let state = RefCell::new(ReadState::new(iter)); + fn new(iter: PyIterable, config: ReaderOption) -> Self { + let state = RefCell::new(ReadState::new(iter, config)); Reader { state } } } @@ -121,7 +175,7 @@ impl Reader { .collect::>>()?; Ok(vm.ctx.new_list(iter)) } - Err(_) => { + Err(_err) => { let msg = String::from("Decode Error"); let decode_error = vm.new_unicode_decode_error(msg); Err(decode_error) @@ -136,9 +190,9 @@ impl Reader { } } -fn csv_reader(fp: PyObjectRef, vm: &VirtualMachine) -> PyResult { +fn csv_reader(fp: PyObjectRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { if let Ok(iterable) = PyIterable::::try_from_object(vm, fp) { - build_reader(iterable, vm) + build_reader(iterable, args, vm) } else { Err(vm.new_type_error("argument 1 must be an iterator".to_string())) } @@ -156,13 +210,13 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { ); py_module!(vm, "_csv", { - "reader" => ctx.new_rustfunc(csv_reader), - "Reader" => reader_type, - "Error" => error, - // constants - "QUOTE_MINIMAL" => ctx.new_int(QuoteStyle::QuoteMinimal as i32), - "QUOTE_ALL" => ctx.new_int(QuoteStyle::QuoteAll as i32), - "QUOTE_NONNUMERIC" => ctx.new_int(QuoteStyle::QuoteNonnumeric as i32), - "QUOTE_NONE" => ctx.new_int(QuoteStyle::QuoteNone as i32), + "reader" => ctx.new_rustfunc(csv_reader), + "Reader" => reader_type, + "Error" => error, + // constants + "QUOTE_MINIMAL" => ctx.new_int(QuoteStyle::QuoteMinimal as i32), + "QUOTE_ALL" => ctx.new_int(QuoteStyle::QuoteAll as i32), + "QUOTE_NONNUMERIC" => ctx.new_int(QuoteStyle::QuoteNonnumeric as i32), + "QUOTE_NONE" => ctx.new_int(QuoteStyle::QuoteNone as i32), }) }