Skip to content

Commit 8a0b2d9

Browse files
committed
update memoryview eq
1 parent cb51ca8 commit 8a0b2d9

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

vm/src/obj/objmemory.rs

+51-17
Original file line numberDiff line numberDiff line change
@@ -558,17 +558,7 @@ impl PyMemoryView {
558558
let bytes = &*zelf.obj_bytes();
559559
let elements: Vec<PyObjectRef> = (0..zelf.options.len)
560560
.map(|i| zelf.get_pos(i as isize).unwrap())
561-
.map(|i| {
562-
zelf.format_spec
563-
.unpack(&bytes[i..i + zelf.options.itemsize], vm)
564-
.map(|x| {
565-
if x.len() == 1 {
566-
x.fast_getitem(0)
567-
} else {
568-
x.into_object()
569-
}
570-
})
571-
})
561+
.map(|i| format_unpack(&zelf.format_spec, &bytes[i..i + zelf.options.itemsize], vm))
572562
.try_collect()?;
573563

574564
Ok(PyList::from(elements).into_ref(vm))
@@ -634,13 +624,15 @@ impl PyMemoryView {
634624
return Ok(false);
635625
}
636626

637-
let options_cmp = |a: &BufferOptions, b: &BufferOptions| -> bool {
638-
a.len == b.len && a.itemsize == b.itemsize
639-
};
640-
641627
let other = try_buffer_from_object(vm, other)?;
642628

643-
if !options_cmp(&zelf.options, &other.get_options()) {
629+
let a_options = &zelf.options;
630+
let b_options = &*other.get_options();
631+
632+
if a_options.len != b_options.len
633+
|| a_options.ndim != b_options.ndim
634+
|| a_options.shape != b_options.shape
635+
{
644636
return Ok(false);
645637
}
646638

@@ -669,7 +661,14 @@ impl PyMemoryView {
669661
}
670662
};
671663

672-
Ok(a == b)
664+
if a_options.format == b_options.format {
665+
Ok(a == b)
666+
} else {
667+
let a_list = unpack_bytes_seq_to_list(a, &a_options.format, vm)?;
668+
let b_list = unpack_bytes_seq_to_list(b, &b_options.format, vm)?;
669+
670+
Ok(vm.bool_eq(a_list.as_object(), b_list.as_object())?)
671+
}
673672
}
674673
}
675674

@@ -788,3 +787,38 @@ pub fn try_buffer_from_object(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResul
788787
obj_cls.name
789788
)))
790789
}
790+
791+
fn format_unpack(
792+
format_spec: &FormatSpec,
793+
bytes: &[u8],
794+
vm: &VirtualMachine,
795+
) -> PyResult<PyObjectRef> {
796+
format_spec.unpack(bytes, vm).map(|x| {
797+
if x.len() == 1 {
798+
x.fast_getitem(0)
799+
} else {
800+
x.into_object()
801+
}
802+
})
803+
}
804+
805+
pub fn unpack_bytes_seq_to_list(
806+
bytes: &[u8],
807+
format: &str,
808+
vm: &VirtualMachine,
809+
) -> PyResult<PyListRef> {
810+
let format_spec = PyMemoryView::parse_format(format, vm)?;
811+
let itemsize = format_spec.size();
812+
813+
if bytes.len() % itemsize != 0 {
814+
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
815+
}
816+
817+
let len = bytes.len() / itemsize;
818+
819+
let elements: Vec<PyObjectRef> = (0..len)
820+
.map(|i| format_unpack(&format_spec, &bytes[i..i + itemsize], vm))
821+
.try_collect()?;
822+
823+
Ok(PyList::from(elements).into_ref(vm))
824+
}

vm/src/stdlib/pystruct.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ pub(crate) mod _struct {
193193
Ok(PyTupleRef::with_elements(items, &vm.ctx))
194194
}
195195

196-
fn size(&self) -> usize {
196+
pub fn size(&self) -> usize {
197197
self.codes.iter().map(FormatCode::size).sum()
198198
}
199199
}

0 commit comments

Comments
 (0)