diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 18112ecb95..a2acff91de 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -26,7 +26,7 @@ use crate::pyobject::{ use crate::vm::VirtualMachine; fn byte_count(bytes: OptionalOption) -> i64 { - bytes.flatten().unwrap_or(-1 as i64) + bytes.flatten().unwrap_or(-1) } fn os_err(vm: &VirtualMachine, err: io::Error) -> PyBaseExceptionRef { #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] @@ -83,28 +83,32 @@ impl BufferedIO { } //Read k bytes from the object and return. - fn read(&mut self, bytes: i64) -> Option> { - let mut buffer = Vec::new(); - + fn read(&mut self, bytes: Option) -> Option> { //for a defined number of bytes, i.e. bytes != -1 - if bytes >= 0 { - let mut handle = self.cursor.clone().take(bytes as u64); - //read handle into buffer - - if handle.read_to_end(&mut buffer).is_err() { - return None; + match bytes.and_then(|v| v.to_usize()) { + Some(bytes) => { + let mut buffer = unsafe { + // Do not move or edit any part of this block without a safety validation. + // `set_len` is guaranteed to be safe only when the new length is less than or equal to the capacity + let mut buffer = Vec::with_capacity(bytes); + buffer.set_len(bytes); + buffer + }; + //read handle into buffer + self.cursor + .read_exact(&mut buffer) + .map_or(None, |_| Some(buffer)) } - //the take above consumes the struct value - //we add this back in with the takes into_inner method - self.cursor = handle.into_inner(); - } else { - //read handle into buffer - if self.cursor.read_to_end(&mut buffer).is_err() { - return None; + None => { + let mut buffer = Vec::new(); + //read handle into buffer + if self.cursor.read_to_end(&mut buffer).is_err() { + None + } else { + Some(buffer) + } } - }; - - Some(buffer) + } } fn tell(&self) -> u64 { @@ -209,7 +213,7 @@ impl PyStringIORef { //If k is undefined || k == -1, then we read all bytes until the end of the file. //This also increments the stream position by the value of k fn read(self, bytes: OptionalOption, vm: &VirtualMachine) -> PyResult { - let data = match self.buffer(vm)?.read(byte_count(bytes)) { + let data = match self.buffer(vm)?.read(bytes.flatten()) { Some(value) => value, None => Vec::new(), }; @@ -263,11 +267,12 @@ fn string_io_new( _args: StringIOArgs, vm: &VirtualMachine, ) -> PyResult { - let flatten = object.flatten(); - let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec()); + let raw_bytes = object + .flatten() + .map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec()); PyStringIO { - buffer: PyRwLock::new(BufferedIO::new(Cursor::new(input))), + buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) @@ -312,7 +317,7 @@ impl PyBytesIORef { //If k is undefined || k == -1, then we read all bytes until the end of the file. //This also increments the stream position by the value of k fn read(self, bytes: OptionalOption, vm: &VirtualMachine) -> PyResult { - match self.buffer(vm)?.read(byte_count(bytes)) { + match self.buffer(vm)?.read(bytes.flatten()) { Some(value) => Ok(vm.ctx.new_bytes(value)), None => Err(vm.new_value_error("Error Retrieving Value".to_owned())), } @@ -363,10 +368,9 @@ fn bytes_io_new( object: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - let raw_bytes = match object { - OptionalArg::Present(Some(ref input)) => input.get_value().to_vec(), - _ => vec![], - }; + let raw_bytes = object + .flatten() + .map_or_else(Vec::new, |input| input.get_value().to_vec()); PyBytesIO { buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), @@ -1446,7 +1450,7 @@ mod tests { cursor: Cursor::new(data.clone()), }; - assert_eq!(buffered.read(bytes).unwrap(), data); + assert_eq!(buffered.read(Some(bytes)).unwrap(), data); } #[test] @@ -1458,7 +1462,7 @@ mod tests { }; assert_eq!(buffered.seek(SeekFrom::Start(count)).unwrap(), count); - assert_eq!(buffered.read(count.clone() as i64).unwrap(), vec![3, 4]); + assert_eq!(buffered.read(Some(count as i64)).unwrap(), vec![3, 4]); } #[test]