Skip to content

Commit 88f5466

Browse files
committed
update memoryview eq
1 parent 6058d65 commit 88f5466

File tree

2 files changed

+79
-35
lines changed

2 files changed

+79
-35
lines changed

vm/src/obj/objmemory.rs

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -558,17 +558,7 @@ impl PyMemoryView {
558558
let bytes = &*zelf.obj_bytes();
559559
let elements: Vec<PyObjectRef> = (0..zelf.options.len)
560560
.map(|i| zelf.get_pos(i as isize).unwrap())
561-
.map(|i| {
562-
zelf.format_spec
563-
.unpack(&bytes[i..i + zelf.options.itemsize], vm)
564-
.map(|x| {
565-
if x.len() == 1 {
566-
x.fast_getitem(0)
567-
} else {
568-
x.into_object()
569-
}
570-
})
571-
})
561+
.map(|i| format_unpack(&zelf.format_spec, &bytes[i..i + zelf.options.itemsize], vm))
572562
.try_collect()?;
573563

574564
Ok(PyList::from(elements).into_ref(vm))
@@ -633,33 +623,52 @@ impl PyMemoryView {
633623
if zelf.released.load() {
634624
return Ok(false);
635625
}
636-
let options_cmp = |a: &BufferOptions, b: &BufferOptions| -> bool {
637-
a.len == b.len && a.itemsize == b.itemsize
626+
627+
let other = try_buffer_from_object(vm, other)?;
628+
629+
let a_options = &zelf.options;
630+
let b_options = &*other.get_options();
631+
632+
if a_options.len != b_options.len
633+
|| a_options.ndim != b_options.ndim
634+
|| a_options.shape != b_options.shape
635+
{
636+
return Ok(false);
637+
}
638+
639+
let a_guard;
640+
let a_vec;
641+
let a = match zelf.as_contiguous() {
642+
Some(bytes) => {
643+
a_guard = bytes;
644+
&*a_guard
645+
}
646+
None => {
647+
a_vec = zelf.to_contiguous();
648+
a_vec.as_slice()
649+
}
638650
};
639-
// TODO: fast pass for contiguous buffer
640-
match other.clone().downcast::<PyMemoryView>() {
641-
Ok(other) => {
642-
if options_cmp(&zelf.options, &other.options) {
643-
let a = Self::tolist(zelf.clone(), vm)?;
644-
let b = Self::tolist(other, vm)?;
645-
if vm.bool_eq(a.as_object(), b.as_object())? {
646-
return Ok(true);
647-
}
648-
}
651+
let b_guard;
652+
let b_vec;
653+
let b = match other.as_contiguous() {
654+
Some(bytes) => {
655+
b_guard = bytes;
656+
&*b_guard
649657
}
650-
Err(other) => {
651-
if let Ok(buffer) = try_buffer_from_object(vm, &other) {
652-
let options = buffer.get_options();
653-
// FIXME
654-
if options_cmp(&zelf.options, &options)
655-
&& (**(Self::tobytes(zelf.clone(), vm)?) == *buffer.obj_bytes())
656-
{
657-
return Ok(true);
658-
}
659-
}
658+
None => {
659+
b_vec = other.to_contiguous();
660+
b_vec.as_slice()
660661
}
662+
};
663+
664+
if a_options.format == b_options.format {
665+
Ok(a == b)
666+
} else {
667+
let a_list = unpack_bytes_seq_to_list(a, &a_options.format, vm)?;
668+
let b_list = unpack_bytes_seq_to_list(b, &b_options.format, vm)?;
669+
670+
Ok(vm.bool_eq(a_list.as_object(), b_list.as_object())?)
661671
}
662-
Ok(false)
663672
}
664673
}
665674

@@ -778,3 +787,38 @@ pub fn try_buffer_from_object(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResul
778787
obj_cls.name
779788
)))
780789
}
790+
791+
fn format_unpack(
792+
format_spec: &FormatSpec,
793+
bytes: &[u8],
794+
vm: &VirtualMachine,
795+
) -> PyResult<PyObjectRef> {
796+
format_spec.unpack(bytes, vm).map(|x| {
797+
if x.len() == 1 {
798+
x.fast_getitem(0)
799+
} else {
800+
x.into_object()
801+
}
802+
})
803+
}
804+
805+
pub fn unpack_bytes_seq_to_list(
806+
bytes: &[u8],
807+
format: &str,
808+
vm: &VirtualMachine,
809+
) -> PyResult<PyListRef> {
810+
let format_spec = PyMemoryView::parse_format(format, vm)?;
811+
let itemsize = format_spec.size();
812+
813+
if bytes.len() % itemsize != 0 {
814+
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
815+
}
816+
817+
let len = bytes.len() / itemsize;
818+
819+
let elements: Vec<PyObjectRef> = (0..len)
820+
.map(|i| format_unpack(&format_spec, &bytes[i..i + itemsize], vm))
821+
.try_collect()?;
822+
823+
Ok(PyList::from(elements).into_ref(vm))
824+
}

vm/src/stdlib/pystruct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ pub(crate) mod _struct {
193193
Ok(PyTupleRef::with_elements(items, &vm.ctx))
194194
}
195195

196-
fn size(&self) -> usize {
196+
pub fn size(&self) -> usize {
197197
self.codes.iter().map(FormatCode::size).sum()
198198
}
199199
}

0 commit comments

Comments
 (0)