diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index bd50cf28d6..ffe837d504 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -1,9 +1,6 @@ use crate::buffer::{BufferOptions, PyBuffer, PyBufferInternal}; -use crate::builtins::bytes::{PyBytes, PyBytesRef}; -use crate::builtins::list::{PyList, PyListRef}; -use crate::builtins::pystr::{PyStr, PyStrRef}; -use crate::builtins::pytype::PyTypeRef; use crate::builtins::slice::PySliceRef; +use crate::builtins::{PyBytes, PyBytesRef, PyList, PyListRef, PyStr, PyStrRef, PyTypeRef}; use crate::bytesinner::bytes_to_hex; use crate::common::{ borrow::{BorrowedValue, BorrowedValueMut}, @@ -12,7 +9,7 @@ use crate::common::{ rc::PyRc, }; use crate::function::{FuncArgs, OptionalArg}; -use crate::sliceable::{convert_slice, saturate_range, wrap_index, SequenceIndex}; +use crate::sliceable::{convert_slice, wrap_index, SequenceIndex}; use crate::slots::{AsBuffer, Comparable, Hashable, PyComparisonOp, SlotConstructor}; use crate::stdlib::pystruct::_struct::FormatSpec; use crate::utils::Either; @@ -22,8 +19,7 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; -use num_bigint::BigInt; -use num_traits::{One, Signed, ToPrimitive, Zero}; +use num_traits::ToPrimitive; use std::fmt::Debug; use std::ops::Deref; @@ -37,7 +33,7 @@ pub struct PyMemoryViewNewArgs { #[derive(Debug)] pub struct PyMemoryView { buffer: PyBuffer, - pub(crate) released: AtomicCell, + released: AtomicCell, // start should always less or equal to the stop // start and stop pointing to the memory index not slice index // if length is not zero than [start, stop) @@ -61,6 +57,24 @@ impl SlotConstructor for PyMemoryView { #[pyimpl(with(Hashable, Comparable, AsBuffer, SlotConstructor))] impl PyMemoryView { + #[cfg(debug_assertions)] + fn validate(self) -> Self { + let options = &self.buffer.options; + let bytes_len = options.len * options.itemsize; + let buffer_len = self.buffer.internal.obj_bytes().len(); + let t1 = self.stop - self.start == bytes_len; + let t2 = buffer_len >= self.stop; + let t3 = buffer_len >= self.start + bytes_len; + assert!(t1); + assert!(t2); + assert!(t3); + self + } + #[cfg(not(debug_assertions))] + fn validate(self) -> Self { + self + } + fn parse_format(format: &str, vm: &VirtualMachine) -> PyResult { FormatSpec::parse(format, vm) } @@ -80,7 +94,8 @@ impl PyMemoryView { step: 1, format_spec, hash: OnceCell::new(), - }) + } + .validate()) } pub fn from_buffer_range( @@ -99,7 +114,8 @@ impl PyMemoryView { step: 1, format_spec, hash: OnceCell::new(), - }) + } + .validate()) } fn as_contiguous(&self) -> Option> { @@ -238,25 +254,20 @@ impl PyMemoryView { fn getitem_by_slice(zelf: PyRef, slice: PySliceRef, vm: &VirtualMachine) -> PyResult { // slicing a memoryview return a new memoryview - let start = slice.start_index(vm)?; - let stop = slice.stop_index(vm)?; - let step = slice.step_index(vm)?.unwrap_or_else(BigInt::one); - if step.is_zero() { - return Err(vm.new_value_error("slice step cannot be zero".to_owned())); - } - let newstep: BigInt = step.clone() * zelf.step; let len = zelf.buffer.options.len; + let (range, step, is_negative_step) = convert_slice(&slice, len, vm)?; + let abs_step = step.unwrap(); + let step = if is_negative_step { + -(abs_step as isize) + } else { + abs_step as isize + }; + let newstep = step * zelf.step; let itemsize = zelf.buffer.options.itemsize; let format_spec = zelf.format_spec.clone(); - if newstep == BigInt::one() { - let range = saturate_range(&start, &stop, len); - let range = if range.end < range.start { - range.start..range.start - } else { - range - }; + if newstep == 1 { let newlen = range.end - range.start; let start = zelf.start + range.start * itemsize; let stop = zelf.start + range.end * itemsize; @@ -273,38 +284,11 @@ impl PyMemoryView { format_spec, hash: OnceCell::new(), } + .validate() .into_object(vm)); } - let (start, stop) = if step.is_negative() { - ( - stop.map(|x| { - if x == -BigInt::one() { - len + BigInt::one() - } else { - x + 1 - } - }), - start.map(|x| { - if x == -BigInt::one() { - BigInt::from(len) - } else { - x + 1 - } - }), - ) - } else { - (start, stop) - }; - - let range = saturate_range(&start, &stop, len); - let newlen = if range.end > range.start { - // overflow is not possible as dividing a positive integer - ((range.end - range.start - 1) / step.abs()) - .to_usize() - .unwrap() - + 1 - } else { + if range.start >= range.end { return Ok(PyMemoryView { buffer: zelf.buffer.clone_with_options(BufferOptions { len: 0, @@ -318,9 +302,16 @@ impl PyMemoryView { format_spec, hash: OnceCell::new(), } + .validate() .into_object(vm)); }; + // overflow is not possible as dividing a positive integer + let newlen = ((range.end - range.start - 1) / abs_step) + .to_usize() + .unwrap() + + 1; + // newlen will be 0 if step is overflowed let newstep = newstep.to_isize().unwrap_or(-1); @@ -348,6 +339,7 @@ impl PyMemoryView { format_spec, hash: OnceCell::new(), } + .validate() .into_object(vm)) } @@ -541,6 +533,7 @@ impl PyMemoryView { hash: OnceCell::new(), ..*zelf } + .validate() .into_ref(vm)) } @@ -610,6 +603,7 @@ impl PyMemoryView { hash: OnceCell::new(), ..*zelf } + .validate() .into_ref(vm)) } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 456160aa90..0d553318db 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -893,7 +893,7 @@ mod _io { // TODO: loop if write() raises an interrupt let res = vm.call_method(raw, "write", (memobj.clone(),)); - memobj.released.store(true); + memobj.release(); self.buffer = std::mem::take(&mut writebuf.data.lock()); res? @@ -1132,7 +1132,7 @@ mod _io { let res = vm.call_method(self.raw.as_ref().unwrap(), "readinto", (memobj.clone(),)); - memobj.released.store(true); + memobj.release(); std::mem::swap(v, &mut readbuf.data.lock()); res?