Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions vm/src/stdlib/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::pyobject::{
use crate::vm::VirtualMachine;

fn byte_count(bytes: OptionalOption<i64>) -> i64 {
bytes.flatten().unwrap_or(-1 as i64)
bytes.flatten().unwrap_or(-1)
}
fn os_err(vm: &VirtualMachine, err: io::Error) -> PyBaseExceptionRef {
#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
Expand Down Expand Up @@ -83,28 +83,32 @@ impl BufferedIO {
}

//Read k bytes from the object and return.
fn read(&mut self, bytes: i64) -> Option<Vec<u8>> {
let mut buffer = Vec::new();

fn read(&mut self, bytes: Option<i64>) -> Option<Vec<u8>> {
//for a defined number of bytes, i.e. bytes != -1
if bytes >= 0 {
let mut handle = self.cursor.clone().take(bytes as u64);
//read handle into buffer

if handle.read_to_end(&mut buffer).is_err() {
return None;
match bytes.and_then(|v| v.to_usize()) {
Some(bytes) => {
let mut buffer = unsafe {
// Do not move or edit any part of this block without a safety validation.
// `set_len` is guaranteed to be safe only when the new length is less than or equal to the capacity
let mut buffer = Vec::with_capacity(bytes);
buffer.set_len(bytes);
buffer
};
//read handle into buffer
self.cursor
.read_exact(&mut buffer)
.map_or(None, |_| Some(buffer))
}
//the take above consumes the struct value
//we add this back in with the takes into_inner method
self.cursor = handle.into_inner();
} else {
//read handle into buffer
if self.cursor.read_to_end(&mut buffer).is_err() {
return None;
None => {
let mut buffer = Vec::new();
//read handle into buffer
if self.cursor.read_to_end(&mut buffer).is_err() {
None
} else {
Some(buffer)
}
}
};

Some(buffer)
}
}

fn tell(&self) -> u64 {
Expand Down Expand Up @@ -209,7 +213,7 @@ impl PyStringIORef {
//If k is undefined || k == -1, then we read all bytes until the end of the file.
//This also increments the stream position by the value of k
fn read(self, bytes: OptionalOption<i64>, vm: &VirtualMachine) -> PyResult {
let data = match self.buffer(vm)?.read(byte_count(bytes)) {
let data = match self.buffer(vm)?.read(bytes.flatten()) {
Some(value) => value,
None => Vec::new(),
};
Expand Down Expand Up @@ -263,11 +267,12 @@ fn string_io_new(
_args: StringIOArgs,
vm: &VirtualMachine,
) -> PyResult<PyStringIORef> {
let flatten = object.flatten();
let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec());
let raw_bytes = object
.flatten()
.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec());

PyStringIO {
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(input))),
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))),
closed: AtomicCell::new(false),
}
.into_ref_with_type(vm, cls)
Expand Down Expand Up @@ -312,7 +317,7 @@ impl PyBytesIORef {
//If k is undefined || k == -1, then we read all bytes until the end of the file.
//This also increments the stream position by the value of k
fn read(self, bytes: OptionalOption<i64>, vm: &VirtualMachine) -> PyResult {
match self.buffer(vm)?.read(byte_count(bytes)) {
match self.buffer(vm)?.read(bytes.flatten()) {
Some(value) => Ok(vm.ctx.new_bytes(value)),
None => Err(vm.new_value_error("Error Retrieving Value".to_owned())),
}
Expand Down Expand Up @@ -363,10 +368,9 @@ fn bytes_io_new(
object: OptionalArg<Option<PyBytesRef>>,
vm: &VirtualMachine,
) -> PyResult<PyBytesIORef> {
let raw_bytes = match object {
OptionalArg::Present(Some(ref input)) => input.get_value().to_vec(),
_ => vec![],
};
let raw_bytes = object
.flatten()
.map_or_else(Vec::new, |input| input.get_value().to_vec());

PyBytesIO {
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))),
Expand Down Expand Up @@ -1446,7 +1450,7 @@ mod tests {
cursor: Cursor::new(data.clone()),
};

assert_eq!(buffered.read(bytes).unwrap(), data);
assert_eq!(buffered.read(Some(bytes)).unwrap(), data);
}

#[test]
Expand All @@ -1458,7 +1462,7 @@ mod tests {
};

assert_eq!(buffered.seek(SeekFrom::Start(count)).unwrap(), count);
assert_eq!(buffered.read(count.clone() as i64).unwrap(), vec![3, 4]);
assert_eq!(buffered.read(Some(count as i64)).unwrap(), vec![3, 4]);
}

#[test]
Expand Down