Skip to content

Commit 111cd3b

Browse files
committed
fix memoryview cmp
1 parent 41c3e13 commit 111cd3b

File tree

2 files changed

+67
-54
lines changed

2 files changed

+67
-54
lines changed

vm/src/builtins/memory.rs

+46-41
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use super::{
2-
PyBytes, PyBytesRef, PyList, PyListRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef,
3-
PyTypeRef,
2+
PyBytes, PyBytesRef, PyListRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyTypeRef,
43
};
54
use crate::common::{
65
borrow::{BorrowedValue, BorrowedValueMut},
@@ -94,7 +93,7 @@ impl PyMemoryView {
9493
let mut zelf = PyMemoryView {
9594
buffer: ManuallyDrop::new(buffer),
9695
released: AtomicCell::new(false),
97-
start: range.start,
96+
start: 0,
9897
format_spec,
9998
desc,
10099
hash: OnceCell::new(),
@@ -363,10 +362,12 @@ impl PyMemoryView {
363362

364363
let mut bytes_mut = dest.buffer.obj_bytes_mut();
365364
let src_bytes = src.obj_bytes();
366-
dest.desc.zip_eq(&src.desc, true, |a_pos, b_pos, len| {
367-
let a_pos = (a_pos + self.start as isize) as usize;
368-
let b_pos = b_pos as usize;
369-
bytes_mut[a_pos..a_pos + len].copy_from_slice(&src_bytes[b_pos..b_pos + len]);
365+
dest.desc.zip_eq(&src.desc, true, |a_range, b_range| {
366+
let a_range = (a_range.start + self.start as isize) as usize
367+
..(a_range.end + self.start as isize) as usize;
368+
let b_range = b_range.start as usize..b_range.end as usize;
369+
bytes_mut[a_range].copy_from_slice(&src_bytes[b_range]);
370+
false
370371
});
371372

372373
Ok(())
@@ -418,8 +419,8 @@ impl PyMemoryView {
418419
}
419420

420421
fn init_len(&mut self) {
421-
let product = self.desc.dim_desc.iter().map(|x| x.0).product();
422-
self.desc.len = product;
422+
let product: usize = self.desc.dim_desc.iter().map(|x| x.0).product();
423+
self.desc.len = product * self.desc.itemsize;
423424
}
424425

425426
fn init_range(&mut self, range: Range<usize>, dim: usize) {
@@ -523,11 +524,14 @@ impl PyMemoryView {
523524
#[pymethod]
524525
fn tolist(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
525526
self.try_not_released(vm)?;
527+
let bytes = self.buffer.obj_bytes();
526528
if self.desc.ndim() == 0 {
527-
// TODO: unpack_single(view->buf, fmt)
528-
return Ok(vm.ctx.new_list(vec![]));
529+
return Ok(vm.ctx.new_list(vec![format_unpack(
530+
&self.format_spec,
531+
&bytes[..self.desc.itemsize],
532+
vm,
533+
)?]));
529534
}
530-
let bytes = self.buffer.obj_bytes();
531535
self._to_list(&bytes, 0, 0, vm)
532536
}
533537

@@ -693,15 +697,36 @@ impl PyMemoryView {
693697
return Ok(vm.bool_eq(&a_val, &b_val)?);
694698
}
695699

696-
zelf.contiguous_or_collect(|a| {
697-
other.contiguous_or_collect(|b| {
698-
// TODO: optimize cmp by format
699-
let a_list = unpack_bytes_seq_to_list(a, a_format_spec, vm)?;
700-
let b_list = unpack_bytes_seq_to_list(b, b_format_spec, vm)?;
701-
702-
vm.bool_eq(a_list.as_object(), b_list.as_object())
703-
})
704-
})
700+
// TODO: optimize cmp by format
701+
let mut ret = Ok(true);
702+
let a_bytes = zelf.buffer.obj_bytes();
703+
let b_bytes = other.obj_bytes();
704+
zelf.desc.zip_eq(&other.desc, false, |a_range, b_range| {
705+
let a_range = (a_range.start + zelf.start as isize) as usize
706+
..(a_range.end + zelf.start as isize) as usize;
707+
let b_range = b_range.start as usize..b_range.end as usize;
708+
let a_val = match format_unpack(&a_format_spec, &a_bytes[a_range], vm) {
709+
Ok(val) => val,
710+
Err(e) => {
711+
ret = Err(e);
712+
return true;
713+
}
714+
};
715+
let b_val = match format_unpack(&b_format_spec, &b_bytes[b_range], vm) {
716+
Ok(val) => val,
717+
Err(e) => {
718+
ret = Err(e);
719+
return true;
720+
}
721+
};
722+
ret = vm.bool_eq(&a_val, &b_val);
723+
if let Ok(b) = ret {
724+
!b
725+
} else {
726+
true
727+
}
728+
});
729+
ret
705730
}
706731

707732
#[pymethod(magic)]
@@ -900,26 +925,6 @@ fn format_unpack(
900925
})
901926
}
902927

903-
fn unpack_bytes_seq_to_list(
904-
bytes: &[u8],
905-
format_spec: &FormatSpec,
906-
vm: &VirtualMachine,
907-
) -> PyResult<PyListRef> {
908-
let itemsize = format_spec.size();
909-
910-
if bytes.len() % itemsize != 0 {
911-
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
912-
}
913-
914-
let len = bytes.len() / itemsize;
915-
916-
let elements: Vec<PyObjectRef> = (0..len)
917-
.map(|i| format_unpack(&format_spec, &bytes[i..i + itemsize], vm))
918-
.try_collect()?;
919-
920-
Ok(PyList::from(elements).into_ref(vm))
921-
}
922-
923928
fn is_equiv_shape(a: &BufferDescriptor, b: &BufferDescriptor) -> bool {
924929
if a.ndim() != b.ndim() {
925930
return false;

vm/src/protocol/buffer.rs

+21-13
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ impl PyBuffer {
7373
}
7474

7575
pub fn collect(&self, buf: &mut Vec<u8>) {
76-
if self.desc.is_contiguous() {
77-
buf.extend_from_slice(&self.obj_bytes());
76+
if let Some(bytes) = self.as_contiguous() {
77+
buf.extend_from_slice(&bytes);
7878
} else {
7979
let bytes = &*self.obj_bytes();
8080
self.desc.for_each_segment(true, |range| {
@@ -86,8 +86,8 @@ impl PyBuffer {
8686
pub fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
8787
let borrowed;
8888
let mut collected;
89-
let v = if self.desc.is_contiguous() {
90-
borrowed = self.obj_bytes();
89+
let v = if let Some(bytes) = self.as_contiguous() {
90+
borrowed = bytes;
9191
&*borrowed
9292
} else {
9393
collected = vec![];
@@ -151,7 +151,7 @@ impl Drop for PyBuffer {
151151
#[derive(Debug, Clone)]
152152
pub struct BufferDescriptor {
153153
/// product(shape) * itemsize
154-
/// NOT the bytes length if buffer is discontiguous
154+
/// bytes length, but not the length for obj_bytes() even is contiguous
155155
pub len: usize,
156156
pub readonly: bool,
157157
pub itemsize: usize,
@@ -296,15 +296,13 @@ impl BufferDescriptor {
296296
}
297297
}
298298

299+
/// zip two BufferDescriptor with the same shape
299300
pub fn zip_eq<F>(&self, other: &Self, try_conti: bool, mut f: F)
300301
where
301-
F: FnMut(isize, isize, usize),
302+
F: FnMut(Range<isize>, Range<isize>) -> bool,
302303
{
303-
debug_assert_eq!(self.itemsize, other.itemsize);
304-
debug_assert_eq!(self.len, other.len);
305-
306304
if self.ndim() == 0 {
307-
f(0, 0, self.itemsize);
305+
f(0..self.itemsize as isize, 0..other.itemsize as isize);
308306
return;
309307
}
310308
if try_conti && self.is_last_dim_contiguous() {
@@ -322,19 +320,29 @@ impl BufferDescriptor {
322320
dim: usize,
323321
f: &mut F,
324322
) where
325-
F: FnMut(isize, isize, usize),
323+
F: FnMut(Range<isize>, Range<isize>) -> bool,
326324
{
327325
let (shape, a_stride, a_suboffset) = self.dim_desc[dim];
328326
let (_b_shape, b_stride, b_suboffset) = other.dim_desc[dim];
329327
debug_assert_eq!(shape, _b_shape);
330328
if dim + 1 == self.ndim() {
331329
if CONTI {
332-
f(a_index, b_index, shape * self.itemsize);
330+
if f(
331+
a_index..a_index + (shape * self.itemsize) as isize,
332+
b_index..b_index + (shape * other.itemsize) as isize,
333+
) {
334+
return;
335+
}
333336
} else {
334337
for _ in 0..shape {
335338
let a_pos = a_index + a_suboffset;
336339
let b_pos = b_index + b_suboffset;
337-
f(a_pos, b_pos, self.itemsize);
340+
if f(
341+
a_pos..a_pos + self.itemsize as isize,
342+
b_pos..b_pos + other.itemsize as isize,
343+
) {
344+
return;
345+
}
338346
a_index += a_stride;
339347
b_index += b_stride;
340348
}

0 commit comments

Comments
 (0)