Skip to content

Commit 896590c

Browse files
committed
update memoryview eq
1 parent 680cce9 commit 896590c

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

vm/src/obj/objmemory.rs

+51-17
Original file line numberDiff line numberDiff line change
@@ -559,17 +559,7 @@ impl PyMemoryView {
559559
let bytes = &*zelf.obj_bytes();
560560
let elements: Vec<PyObjectRef> = (0..zelf.options.len)
561561
.map(|i| zelf.get_pos(i as isize).unwrap())
562-
.map(|i| {
563-
zelf.format_spec
564-
.unpack(&bytes[i..i + zelf.options.itemsize], vm)
565-
.map(|x| {
566-
if x.len() == 1 {
567-
x.fast_getitem(0)
568-
} else {
569-
x.into_object()
570-
}
571-
})
572-
})
562+
.map(|i| format_unpack(&zelf.format_spec, &bytes[i..i + zelf.options.itemsize], vm))
573563
.try_collect()?;
574564

575565
Ok(PyList::from(elements).into_ref(vm))
@@ -635,13 +625,15 @@ impl PyMemoryView {
635625
return Ok(false);
636626
}
637627

638-
let options_cmp = |a: &BufferOptions, b: &BufferOptions| -> bool {
639-
a.len == b.len && a.itemsize == b.itemsize
640-
};
641-
642628
let other = try_buffer_from_object(vm, other)?;
643629

644-
if !options_cmp(&zelf.options, &other.get_options()) {
630+
let a_options = &zelf.options;
631+
let b_options = &*other.get_options();
632+
633+
if a_options.len != b_options.len
634+
|| a_options.ndim != b_options.ndim
635+
|| a_options.shape != b_options.shape
636+
{
645637
return Ok(false);
646638
}
647639

@@ -670,7 +662,14 @@ impl PyMemoryView {
670662
}
671663
};
672664

673-
Ok(a == b)
665+
if a_options.format == b_options.format {
666+
Ok(a == b)
667+
} else {
668+
let a_list = unpack_bytes_seq_to_list(a, &a_options.format, vm)?;
669+
let b_list = unpack_bytes_seq_to_list(b, &b_options.format, vm)?;
670+
671+
Ok(vm.bool_eq(a_list.as_object(), b_list.as_object())?)
672+
}
674673
}
675674
}
676675

@@ -789,3 +788,38 @@ pub fn try_buffer_from_object(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResul
789788
obj_cls.name
790789
)))
791790
}
791+
792+
fn format_unpack(
793+
format_spec: &FormatSpec,
794+
bytes: &[u8],
795+
vm: &VirtualMachine,
796+
) -> PyResult<PyObjectRef> {
797+
format_spec.unpack(bytes, vm).map(|x| {
798+
if x.len() == 1 {
799+
x.fast_getitem(0)
800+
} else {
801+
x.into_object()
802+
}
803+
})
804+
}
805+
806+
pub fn unpack_bytes_seq_to_list(
807+
bytes: &[u8],
808+
format: &str,
809+
vm: &VirtualMachine,
810+
) -> PyResult<PyListRef> {
811+
let format_spec = PyMemoryView::parse_format(format, vm)?;
812+
let itemsize = format_spec.size();
813+
814+
if bytes.len() % itemsize != 0 {
815+
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
816+
}
817+
818+
let len = bytes.len() / itemsize;
819+
820+
let elements: Vec<PyObjectRef> = (0..len)
821+
.map(|i| format_unpack(&format_spec, &bytes[i..i + itemsize], vm))
822+
.try_collect()?;
823+
824+
Ok(PyList::from(elements).into_ref(vm))
825+
}

vm/src/stdlib/pystruct.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ pub(crate) mod _struct {
193193
Ok(PyTupleRef::with_elements(items, &vm.ctx))
194194
}
195195

196-
fn size(&self) -> usize {
196+
pub fn size(&self) -> usize {
197197
self.codes.iter().map(FormatCode::size).sum()
198198
}
199199
}

0 commit comments

Comments
 (0)