diff --git a/stdlib/src/array.rs b/stdlib/src/array.rs index 29a675e110..77273bda90 100644 --- a/stdlib/src/array.rs +++ b/stdlib/src/array.rs @@ -17,11 +17,13 @@ mod array { }, class_or_notimplemented, function::{ArgBytesLike, ArgIterable, OptionalArg}, - protocol::{BufferInternal, BufferOptions, PyBuffer, PyIterReturn, ResizeGuard}, + protocol::{ + BufferInternal, BufferOptions, PyBuffer, PyIterReturn, PyMappingMethods, ResizeGuard, + }, sliceable::{saturate_index, PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, slots::{ - AsBuffer, Comparable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, - SlotIterator, + AsBuffer, AsMapping, Comparable, Iterable, IteratorIterable, PyComparisonOp, + SlotConstructor, SlotIterator, }, IdProtocol, IntoPyObject, IntoPyResult, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, @@ -665,7 +667,10 @@ mod array { } } - #[pyimpl(flags(BASETYPE), with(Comparable, AsBuffer, Iterable, SlotConstructor))] + #[pyimpl( + flags(BASETYPE), + with(Comparable, AsBuffer, AsMapping, Iterable, SlotConstructor) + )] impl PyArray { fn read(&self) -> PyRwLockReadGuard<'_, ArrayContentType> { self.array.read() @@ -1161,6 +1166,38 @@ mod array { } } + impl AsMapping for PyArray { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: Some(Self::ass_subscript), + }) + } + + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.getitem(needle, vm))? + } + + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + match value { + Some(value) => { + Self::downcast(zelf, vm).map(|zelf| Self::setitem(zelf, needle, value, vm))? + } + None => Self::downcast(zelf, vm).map(|zelf| Self::delitem(zelf, needle, vm))?, + } + } + } + impl Iterable for PyArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyArrayIter { diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 2ab6091dcd..2d8e4c849f 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -15,11 +15,13 @@ use crate::{ ByteInnerTranslateOptions, DecodeArgs, PyBytesInner, }, function::{ArgBytesLike, ArgIterable, FuncArgs, OptionalArg, OptionalOption}, - protocol::{BufferInternal, BufferOptions, PyBuffer, PyIterReturn, ResizeGuard}, + protocol::{ + BufferInternal, BufferOptions, PyBuffer, PyIterReturn, PyMappingMethods, ResizeGuard, + }, sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, slots::{ - AsBuffer, Callable, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, - SlotIterator, Unhashable, + AsBuffer, AsMapping, Callable, Comparable, Hashable, Iterable, IteratorIterable, + PyComparisonOp, SlotIterator, Unhashable, }, utils::Either, IdProtocol, IntoPyObject, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, @@ -95,7 +97,10 @@ pub(crate) fn init(context: &PyContext) { PyByteArrayIterator::extend_class(context, &context.types.bytearray_iterator_type); } -#[pyimpl(flags(BASETYPE), with(Hashable, Comparable, AsBuffer, Iterable))] +#[pyimpl( + flags(BASETYPE), + with(Hashable, Comparable, AsBuffer, AsMapping, Iterable) +)] impl PyByteArray { #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -201,9 +206,9 @@ impl PyByteArray { } #[pymethod(magic)] - pub fn delitem(&self, needle: SequenceIndex, vm: &VirtualMachine) -> PyResult<()> { + pub fn delitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let elements = &mut self.try_resizable(vm)?.elements; - match needle { + match SequenceIndex::try_from_object_for(vm, needle, Self::NAME)? { SequenceIndex::Int(int) => { if let Some(idx) = elements.wrap_index(int) { elements.remove(idx); @@ -712,6 +717,41 @@ impl<'a> ResizeGuard<'a> for PyByteArray { } } +impl AsMapping for PyByteArray { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: Some(Self::ass_subscript), + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.getitem(needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + match value { + Some(value) => { + Self::downcast(zelf, vm).map(|zelf| Self::setitem(zelf, needle, value, vm)) + } + None => Self::downcast_ref(&zelf, vm).map(|zelf| zelf.delitem(needle, vm)), + }? + } +} + impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 2df6000179..f24d5e345e 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -7,10 +7,10 @@ use crate::{ }, common::hash::PyHash, function::{ArgBytesLike, ArgIterable, OptionalArg, OptionalOption}, - protocol::{BufferInternal, BufferOptions, PyBuffer, PyIterReturn}, + protocol::{BufferInternal, BufferOptions, PyBuffer, PyIterReturn, PyMappingMethods}, slots::{ - AsBuffer, Callable, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, - SlotConstructor, SlotIterator, + AsBuffer, AsMapping, Callable, Comparable, Hashable, Iterable, IteratorIterable, + PyComparisonOp, SlotConstructor, SlotIterator, }, utils::Either, IdProtocol, IntoPyObject, IntoPyResult, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, @@ -103,7 +103,7 @@ impl SlotConstructor for PyBytes { #[pyimpl( flags(BASETYPE), - with(Hashable, Comparable, AsBuffer, Iterable, SlotConstructor) + with(AsMapping, Hashable, Comparable, AsBuffer, Iterable, SlotConstructor) )] impl PyBytes { #[pymethod(magic)] @@ -540,6 +540,36 @@ impl BufferInternal for PyRef { fn retain(&self) {} } +impl AsMapping for PyBytes { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: None, + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.getitem(needle, vm))? + } + + #[cold] + fn ass_subscript( + zelf: PyObjectRef, + _needle: PyObjectRef, + _value: Option, + _vm: &VirtualMachine, + ) -> PyResult<()> { + unreachable!("ass_subscript not implemented for {}", zelf.class()) + } +} + impl Hashable for PyBytes { fn hash(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { Ok(zelf.inner.hash(vm)) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index c35b100c76..122831d0b4 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -4,15 +4,16 @@ use crate::{ common::ascii, dictdatatype::{self, DictKey}, function::{ArgIterable, FuncArgs, KwArgs, OptionalArg}, - protocol::PyIterReturn, + protocol::{PyIterReturn, PyMappingMethods}, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, Unhashable, + AsMapping, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, + Unhashable, }, vm::{ReprGuard, VirtualMachine}, IdProtocol, IntoPyObject, ItemProtocol, PyArithmeticValue::*, PyAttributes, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, - PyResult, PyValue, TryFromObject, TypeProtocol, + PyResult, PyValue, TypeProtocol, }; use crossbeam_utils::atomic::AtomicCell; use std::fmt; @@ -51,7 +52,7 @@ impl PyValue for PyDict { // Python dict methods: #[allow(clippy::len_without_is_empty)] -#[pyimpl(with(Hashable, Comparable, Iterable), flags(BASETYPE))] +#[pyimpl(with(AsMapping, Hashable, Comparable, Iterable), flags(BASETYPE))] impl PyDict { #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -410,6 +411,39 @@ impl PyDict { } } +impl AsMapping for PyDict { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: Some(Self::ass_subscript), + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast(zelf, vm).map(|zelf| Self::getitem(zelf, needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + Self::downcast_ref(&zelf, vm).map(|zelf| match value { + Some(value) => zelf.setitem(needle, value, vm), + None => zelf.delitem(needle, vm), + })? + } +} + impl Comparable for PyDict { fn cmp( zelf: &PyRef, @@ -651,7 +685,7 @@ macro_rules! dict_iterator { } impl $name { - fn new(dict: PyDictRef) -> Self { + pub fn new(dict: PyDictRef) -> Self { $name { dict } } } @@ -901,26 +935,3 @@ pub(crate) fn init(context: &PyContext) { PyDictItemIterator::extend_class(context, &context.types.dict_itemiterator_type); PyDictReverseItemIterator::extend_class(context, &context.types.dict_reverseitemiterator_type); } - -pub struct PyMapping { - dict: PyDictRef, -} - -impl TryFromObject for PyMapping { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let dict = vm.ctx.new_dict(); - PyDict::merge( - &dict.entries, - OptionalArg::Present(obj), - KwArgs::default(), - vm, - )?; - Ok(PyMapping { dict }) - } -} - -impl PyMapping { - pub fn into_dict(self) -> PyDictRef { - self.dict - } -} diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 1f0fdfad58..bee377e7aa 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -8,11 +8,12 @@ use crate::common::lock::{ }; use crate::{ function::{ArgIterable, FuncArgs, OptionalArg}, - protocol::PyIterReturn, + protocol::{PyIterReturn, PyMappingMethods}, sequence::{self, SimpleSeq}, sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, Unhashable, + AsMapping, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, + Unhashable, }, utils::Either, vm::{ReprGuard, VirtualMachine}, @@ -82,7 +83,7 @@ pub(crate) struct SortOptions { pub type PyListRef = PyRef; -#[pyimpl(with(Iterable, Hashable, Comparable), flags(BASETYPE))] +#[pyimpl(with(AsMapping, Iterable, Hashable, Comparable), flags(BASETYPE))] impl PyList { #[pymethod] pub(crate) fn append(&self, x: PyObjectRef) { @@ -353,8 +354,8 @@ impl PyList { } #[pymethod(magic)] - fn delitem(&self, subscript: SequenceIndex, vm: &VirtualMachine) -> PyResult<()> { - match subscript { + fn delitem(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + match SequenceIndex::try_from_object_for(vm, subscript, Self::NAME)? { SequenceIndex::Int(index) => self.delindex(index, vm), SequenceIndex::Slice(slice) => self.delslice(slice, vm), } @@ -416,6 +417,39 @@ impl PyList { } } +impl AsMapping for PyList { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: Some(Self::ass_subscript), + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast(zelf, vm).map(|zelf| Self::getitem(zelf, needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + Self::downcast_ref(&zelf, vm).map(|zelf| match value { + Some(value) => zelf.setitem(needle, value, vm), + None => zelf.delitem(needle, vm), + })? + } +} + impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { diff --git a/vm/src/builtins/mappingproxy.rs b/vm/src/builtins/mappingproxy.rs index fcf7d90716..225fd14326 100644 --- a/vm/src/builtins/mappingproxy.rs +++ b/vm/src/builtins/mappingproxy.rs @@ -1,9 +1,10 @@ -use super::{PyDict, PyStrRef, PyTypeRef}; +use super::{PyDict, PyList, PyStrRef, PyTuple, PyTypeRef}; use crate::{ function::OptionalArg, - slots::{Iterable, SlotConstructor}, + protocol::{PyMapping, PyMappingMethods}, + slots::{AsMapping, Iterable, SlotConstructor}, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, VirtualMachine, + TryFromObject, TypeProtocol, VirtualMachine, }; #[pyclass(module = false, name = "mappingproxy")] @@ -36,14 +37,24 @@ impl SlotConstructor for PyMappingProxy { type Args = PyObjectRef; fn py_new(cls: PyTypeRef, mapping: Self::Args, vm: &VirtualMachine) -> PyResult { - Self { - mapping: MappingProxyInner::Dict(mapping), + if !PyMapping::check(&mapping) + || mapping.payload_if_subclass::(vm).is_some() + || mapping.payload_if_subclass::(vm).is_some() + { + Err(vm.new_type_error(format!( + "mappingproxy() argument must be a mapping, not {}", + mapping.class() + ))) + } else { + Self { + mapping: MappingProxyInner::Dict(mapping), + } + .into_pyresult_with_type(vm, cls) } - .into_pyresult_with_type(vm, cls) } } -#[pyimpl(with(Iterable, SlotConstructor))] +#[pyimpl(with(AsMapping, Iterable, SlotConstructor))] impl PyMappingProxy { fn get_inner(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult> { let opt = match &self.mapping { @@ -127,6 +138,37 @@ impl PyMappingProxy { } } } + +impl AsMapping for PyMappingProxy { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: None, + subscript: Some(Self::subscript), + ass_subscript: None, + }) + } + + #[inline] + fn length(zelf: PyObjectRef, _vm: &VirtualMachine) -> PyResult { + unreachable!("length not implemented for {}", zelf.class()) + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.getitem(needle, vm))? + } + + #[cold] + fn ass_subscript( + zelf: PyObjectRef, + _needle: PyObjectRef, + _value: Option, + _vm: &VirtualMachine, + ) -> PyResult<()> { + unreachable!("ass_subscript not implemented for {}", zelf.class()) + } +} + impl Iterable for PyMappingProxy { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { let obj = match &zelf.mapping { diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index eaf8539c77..1dd313dd44 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -8,13 +8,13 @@ use crate::common::{ use crate::{ bytesinner::bytes_to_hex, function::{FuncArgs, OptionalArg}, - protocol::{BufferInternal, BufferOptions, PyBuffer}, + protocol::{BufferInternal, BufferOptions, PyBuffer, PyMappingMethods}, sliceable::{convert_slice, wrap_index, SequenceIndex}, - slots::{AsBuffer, Comparable, Hashable, PyComparisonOp, SlotConstructor}, + slots::{AsBuffer, AsMapping, Comparable, Hashable, PyComparisonOp, SlotConstructor}, stdlib::pystruct::FormatSpec, utils::Either, - IdProtocol, IntoPyObject, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, - PyResult, PyValue, TryFromBorrowedObject, TryFromObject, TypeProtocol, VirtualMachine, + IdProtocol, IntoPyObject, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, + PyRef, PyResult, PyValue, TryFromBorrowedObject, TryFromObject, TypeProtocol, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; @@ -53,7 +53,7 @@ impl SlotConstructor for PyMemoryView { } } -#[pyimpl(with(Hashable, Comparable, AsBuffer, SlotConstructor))] +#[pyimpl(with(Hashable, Comparable, AsBuffer, AsMapping, SlotConstructor))] impl PyMemoryView { #[cfg(debug_assertions)] fn validate(self) -> Self { @@ -342,10 +342,10 @@ impl PyMemoryView { } #[pymethod(magic)] - fn getitem(zelf: PyRef, needle: SequenceIndex, vm: &VirtualMachine) -> PyResult { + fn getitem(zelf: PyRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { zelf.try_not_released(vm)?; - match needle { - SequenceIndex::Int(i) => Self::getitem_by_idx(zelf, i, vm), + match SequenceIndex::try_from_object_for(vm, needle, Self::NAME)? { + SequenceIndex::Int(index) => Self::getitem_by_idx(zelf, index, vm), SequenceIndex::Slice(slice) => Self::getitem_by_slice(zelf, slice, vm), } } @@ -459,7 +459,7 @@ impl PyMemoryView { #[pymethod(magic)] fn setitem( zelf: PyRef, - needle: SequenceIndex, + needle: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { @@ -467,8 +467,8 @@ impl PyMemoryView { if zelf.buffer.options.readonly { return Err(vm.new_type_error("cannot modify read-only memory".to_owned())); } - match needle { - SequenceIndex::Int(i) => Self::setitem_by_idx(zelf, i, value, vm), + match SequenceIndex::try_from_object_for(vm, needle, Self::NAME)? { + SequenceIndex::Int(index) => Self::setitem_by_idx(zelf, index, value, vm), SequenceIndex::Slice(slice) => Self::setitem_by_slice(zelf, slice, value, vm), } } @@ -734,6 +734,41 @@ impl BufferInternal for PyRef { fn retain(&self) {} } +impl AsMapping for PyMemoryView { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: Some(Self::ass_subscript), + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.len(vm))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast(zelf, vm).map(|zelf| Self::getitem(zelf, needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + match value { + Some(value) => { + Self::downcast(zelf, vm).map(|zelf| Self::setitem(zelf, needle, value, vm))? + } + None => Err(vm.new_type_error("cannot delete memory".to_owned())), + } + } +} + impl Comparable for PyMemoryView { fn cmp( zelf: &PyRef, diff --git a/vm/src/builtins/pytype.rs b/vm/src/builtins/pytype.rs index b9f67d096d..9db1530c8a 100644 --- a/vm/src/builtins/pytype.rs +++ b/vm/src/builtins/pytype.rs @@ -5,13 +5,14 @@ use super::{ use crate::common::{ascii, lock::PyRwLock}; use crate::{ function::{FuncArgs, KwArgs, OptionalArg}, - protocol::PyIterReturn, + protocol::{PyIterReturn, PyMappingMethods}, slots::{self, Callable, PyTypeFlags, PyTypeSlots, SlotGetattro, SlotSetattro}, utils::Either, IdProtocol, PyAttributes, PyClassImpl, PyContext, PyLease, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, }; use itertools::Itertools; +use num_traits::ToPrimitive; use std::collections::HashSet; use std::fmt; use std::ops::Deref; @@ -270,6 +271,54 @@ impl PyType { }; update_slot!(iternext, func); } + "__len__" | "__getitem__" | "__setitem__" | "__delitem__" => { + macro_rules! then_some_closure { + ($cond:expr, $closure:expr) => { + if $cond { + Some($closure) + } else { + None + } + }; + } + + let func: slots::MappingFunc = |zelf, _vm| { + Ok(PyMappingMethods { + length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| { + vm.call_special_method(zelf, "__len__", ()).map(|obj| { + obj.payload_if_subclass::(vm) + .map(|length_obj| { + length_obj.as_bigint().to_usize().ok_or_else(|| { + vm.new_value_error( + "__len__() should return >= 0".to_owned(), + ) + }) + }) + .unwrap() + })? + }), + subscript: then_some_closure!( + zelf.has_class_attr("__getitem__"), + |zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine| { + vm.call_special_method(zelf, "__getitem__", (needle,)) + } + ), + ass_subscript: then_some_closure!( + zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), + |zelf, needle, value, vm| match value { + Some(value) => vm + .call_special_method(zelf, "__setitem__", (needle, value),) + .map(|_| Ok(()))?, + None => vm + .call_special_method(zelf, "__delitem__", (needle,)) + .map(|_| Ok(()))?, + } + ), + }) + }; + update_slot!(as_mapping, func); + // TODO: need to update sequence protocol too + } _ => {} } } diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 4f6db0a594..6c9d0a0dbb 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -2,8 +2,10 @@ use super::{PyInt, PyIntRef, PySlice, PySliceRef, PyTypeRef}; use crate::common::hash::PyHash; use crate::{ function::{FuncArgs, OptionalArg}, - protocol::PyIterReturn, - slots::{Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator}, + protocol::{PyIterReturn, PyMappingMethods}, + slots::{ + AsMapping, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, + }, IdProtocol, IntoPyRef, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, }; @@ -98,7 +100,7 @@ impl PyRange { #[inline] pub fn is_empty(&self) -> bool { - self.length().is_zero() + self.compute_length().is_zero() } #[inline] @@ -116,7 +118,7 @@ impl PyRange { } if index.is_negative() { - let length = self.length(); + let length = self.compute_length(); let index: BigInt = &length + index; if index.is_negative() { return None; @@ -143,7 +145,7 @@ impl PyRange { } #[inline] - fn length(&self) -> BigInt { + fn compute_length(&self) -> BigInt { let start = self.start.as_bigint(); let stop = self.stop.as_bigint(); let step = self.step.as_bigint(); @@ -173,7 +175,7 @@ pub fn init(context: &PyContext) { PyRangeIterator::extend_class(context, &context.types.range_iterator_type); } -#[pyimpl(with(Hashable, Comparable, Iterable))] +#[pyimpl(with(AsMapping, Hashable, Comparable, Iterable))] impl PyRange { fn new(cls: PyTypeRef, stop: PyIntRef, vm: &VirtualMachine) -> PyResult> { PyRange { @@ -251,7 +253,7 @@ impl PyRange { #[pymethod(magic)] fn len(&self) -> BigInt { - self.length() + self.compute_length() } #[pymethod(magic)] @@ -331,11 +333,11 @@ impl PyRange { } #[pymethod(magic)] - fn getitem(&self, subscript: RangeIndex, vm: &VirtualMachine) -> PyResult { - match subscript { + fn getitem(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult { + match RangeIndex::try_from_object(vm, subscript)? { RangeIndex::Slice(slice) => { let (mut substart, mut substop, mut substep) = - slice.inner_indices(&self.length(), vm)?; + slice.inner_indices(&self.compute_length(), vm)?; let range_step = &self.step; let range_start = &self.start; @@ -372,9 +374,39 @@ impl PyRange { } } +impl AsMapping for PyRange { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: None, + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len().to_usize().unwrap()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| zelf.getitem(needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + _needle: PyObjectRef, + _value: Option, + _vm: &VirtualMachine, + ) -> PyResult<()> { + unreachable!("ass_subscript not implemented for {}", zelf.class()) + } +} + impl Hashable for PyRange { fn hash(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let length = zelf.length(); + let length = zelf.compute_length(); let elements = if length.is_zero() { [vm.ctx.new_int(length), vm.ctx.none(), vm.ctx.none()] } else if length.is_one() { @@ -406,8 +438,8 @@ impl Comparable for PyRange { return Ok(true.into()); } let rhs = class_or_notimplemented!(Self, other); - let lhs_len = zelf.length(); - let eq = if lhs_len != rhs.length() { + let lhs_len = zelf.compute_length(); + let eq = if lhs_len != rhs.compute_length() { false } else if lhs_len.is_zero() { true diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 03c11d4e50..3f5e8e51df 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -6,12 +6,12 @@ use super::{ use crate::common::hash::PyHash; use crate::{ function::OptionalArg, - protocol::PyIterReturn, + protocol::{PyIterReturn, PyMappingMethods}, sequence::{self, SimpleSeq}, sliceable::PySliceableSequence, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, - SlotIterator, + AsMapping, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, + SlotConstructor, SlotIterator, }, utils::Either, vm::{ReprGuard, VirtualMachine}, @@ -111,7 +111,10 @@ impl SlotConstructor for PyTuple { } } -#[pyimpl(flags(BASETYPE), with(Hashable, Comparable, Iterable, SlotConstructor))] +#[pyimpl( + flags(BASETYPE), + with(AsMapping, Hashable, Comparable, Iterable, SlotConstructor) +)] impl PyTuple { /// Creating a new tuple with given boxed slice. /// NOTE: for usual case, you probably want to use PyTupleRef::with_elements. @@ -284,6 +287,36 @@ impl PyTuple { } } +impl AsMapping for PyTuple { + fn as_mapping(_zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(PyMappingMethods { + length: Some(Self::length), + subscript: Some(Self::subscript), + ass_subscript: None, + }) + } + + #[inline] + fn length(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast_ref(&zelf, vm).map(|zelf| Ok(zelf.len()))? + } + + #[inline] + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Self::downcast(zelf, vm).map(|zelf| Self::getitem(zelf, needle, vm))? + } + + #[inline] + fn ass_subscript( + zelf: PyObjectRef, + _needle: PyObjectRef, + _value: Option, + _vm: &VirtualMachine, + ) -> PyResult<()> { + unreachable!("ass_subscript not implemented for {}", zelf.class()) + } +} + impl Hashable for PyTuple { fn hash(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { crate::utils::hash_iter(zelf.elements.iter(), vm) diff --git a/vm/src/protocol/mapping.rs b/vm/src/protocol/mapping.rs new file mode 100644 index 0000000000..bf832cf1f7 --- /dev/null +++ b/vm/src/protocol/mapping.rs @@ -0,0 +1,135 @@ +use crate::{ + builtins::dict::{PyDictKeys, PyDictRef, PyDictValues}, + builtins::list::PyList, + vm::VirtualMachine, + IdProtocol, IntoPyObject, PyObjectRef, PyResult, TryFromObject, TypeProtocol, +}; +use std::borrow::Borrow; +use std::ops::Deref; + +// Mapping protocol +// https://docs.python.org/3/c-api/mapping.html +#[allow(clippy::type_complexity)] +#[derive(Default)] +pub struct PyMappingMethods { + pub length: Option PyResult>, + pub subscript: Option PyResult>, + pub ass_subscript: + Option, &VirtualMachine) -> PyResult<()>>, +} + +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyMapping(T) +where + T: Borrow; + +impl PyMapping { + pub fn into_object(self) -> PyObjectRef { + self.0 + } + + pub fn check(obj: &PyObjectRef) -> bool { + obj.class() + .mro_find_map(|x| x.slots.as_mapping.load()) + .is_some() + } + + pub fn methods(&self, vm: &VirtualMachine) -> PyMappingMethods { + let obj_cls = self.0.class(); + for cls in obj_cls.iter_mro() { + if let Some(f) = cls.slots.as_mapping.load() { + return f(&self.0, vm).unwrap(); + } + } + PyMappingMethods::default() + } +} + +impl PyMapping +where + T: Borrow, +{ + pub fn new(obj: T) -> Self { + Self(obj) + } + + pub fn keys(&self, vm: &VirtualMachine) -> PyResult { + if self.0.borrow().is(&vm.ctx.types.dict_type) { + Ok( + PyDictKeys::new(PyDictRef::try_from_object(vm, self.0.borrow().clone())?) + .into_pyobject(vm), + ) + } else { + Self::method_output_as_list(self.0.borrow(), "keys", vm) + } + } + + pub fn values(&self, vm: &VirtualMachine) -> PyResult { + if self.0.borrow().is(&vm.ctx.types.dict_type) { + Ok( + PyDictValues::new(PyDictRef::try_from_object(vm, self.0.borrow().clone())?) + .into_pyobject(vm), + ) + } else { + Self::method_output_as_list(self.0.borrow(), "values", vm) + } + } + + fn method_output_as_list( + obj: &PyObjectRef, + method_name: &str, + vm: &VirtualMachine, + ) -> PyResult { + let meth_output = vm.call_method(obj, method_name, ())?; + if meth_output.is(&vm.ctx.types.list_type) { + return Ok(meth_output); + } + + let iter = meth_output.clone().get_iter(vm).map_err(|_| { + vm.new_type_error(format!( + "{}.{}() returned a non-iterable (type {})", + obj.class(), + method_name, + meth_output.class() + )) + })?; + + Ok(PyList::from(vm.extract_elements(&iter)?).into_pyobject(vm)) + } +} + +impl Borrow for PyMapping +where + T: Borrow, +{ + fn borrow(&self) -> &PyObjectRef { + self.0.borrow() + } +} + +impl Deref for PyMapping +where + T: Borrow, +{ + type Target = PyObjectRef; + fn deref(&self) -> &Self::Target { + self.0.borrow() + } +} + +impl IntoPyObject for PyMapping { + fn into_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef { + self.into_object() + } +} + +impl TryFromObject for PyMapping { + fn try_from_object(vm: &VirtualMachine, mapping: PyObjectRef) -> PyResult { + if Self::check(&mapping) { + Ok(Self::new(mapping)) + } else { + Err(vm.new_type_error(format!("{} is not a mapping object", mapping.class()))) + } + } +} diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index 0c44ab7545..548eaaee12 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -1,5 +1,7 @@ mod buffer; mod iter; +mod mapping; pub use buffer::{BufferInternal, BufferOptions, PyBuffer, ResizeGuard}; pub use iter::{PyIter, PyIterReturn}; +pub use mapping::{PyMapping, PyMappingMethods}; diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 514a88823a..b51e5a000d 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -20,6 +20,7 @@ use crate::{ dictdatatype::Dict, exceptions, function::{IntoFuncArgs, IntoPyNativeFunc}, + protocol::PyMapping, slots::{PyTypeFlags, PyTypeSlots}, types::{create_type_with_slots, TypeZoo}, VirtualMachine, @@ -639,6 +640,12 @@ where T: IntoPyObject, { fn get_item(&self, key: T, vm: &VirtualMachine) -> PyResult { + if let Ok(mapping) = PyMapping::try_from_object(vm, self.clone()) { + if let Some(getitem) = mapping.methods(vm).subscript { + return getitem(self.clone(), key.into_pyobject(vm), vm); + } + } + match vm.get_special_method(self.clone(), "__getitem__")? { Ok(special_method) => return special_method.invoke((key,), vm), Err(obj) => { @@ -656,6 +663,12 @@ where } fn set_item(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Ok(mapping) = PyMapping::try_from_object(vm, self.clone()) { + if let Some(setitem) = mapping.methods(vm).ass_subscript { + return setitem(self.clone(), key.into_pyobject(vm), Some(value), vm); + } + } + vm.get_special_method(self.clone(), "__setitem__")? .map_err(|obj| { vm.new_type_error(format!( @@ -668,6 +681,12 @@ where } fn del_item(&self, key: T, vm: &VirtualMachine) -> PyResult<()> { + if let Ok(mapping) = PyMapping::try_from_object(vm, self.clone()) { + if let Some(setitem) = mapping.methods(vm).ass_subscript { + return setitem(self.clone(), key.into_pyobject(vm), None, vm); + } + } + vm.get_special_method(self.clone(), "__delitem__")? .map_err(|obj| { vm.new_type_error(format!( diff --git a/vm/src/slots.rs b/vm/src/slots.rs index 37382e6ecd..99c25264e3 100644 --- a/vm/src/slots.rs +++ b/vm/src/slots.rs @@ -2,7 +2,7 @@ use crate::builtins::{PyStrRef, PyTypeRef}; use crate::common::hash::PyHash; use crate::common::lock::PyRwLock; use crate::function::{FromArgs, FuncArgs, OptionalArg}; -use crate::protocol::{PyBuffer, PyIterReturn}; +use crate::protocol::{PyBuffer, PyIterReturn, PyMappingMethods}; use crate::utils::Either; use crate::VirtualMachine; use crate::{ @@ -70,6 +70,7 @@ pub(crate) type GetattroFunc = fn(PyObjectRef, PyStrRef, &VirtualMachine) -> PyR pub(crate) type SetattroFunc = fn(&PyObjectRef, PyStrRef, Option, &VirtualMachine) -> PyResult<()>; pub(crate) type BufferFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; +pub(crate) type MappingFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; pub(crate) type IterFunc = fn(PyObjectRef, &VirtualMachine) -> PyResult; pub(crate) type IterNextFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; @@ -85,7 +86,7 @@ pub struct PyTypeSlots { // Method suites for standard classes // tp_as_number // tp_as_sequence - // tp_as_mapping + pub as_mapping: AtomicCell>, // More standard operations (here for binary compatibility) pub hash: AtomicCell>, @@ -532,6 +533,50 @@ pub trait AsBuffer: PyValue { fn as_buffer(zelf: &PyRef, vm: &VirtualMachine) -> PyResult; } +#[pyimpl] +pub trait AsMapping: PyValue { + #[pyslot] + fn slot_as_mapping(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf = zelf + .downcast_ref() + .ok_or_else(|| vm.new_type_error("unexpected payload for as_mapping".to_owned()))?; + Self::as_mapping(zelf, vm) + } + + fn downcast(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + zelf.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "{} type is required, not {}", + Self::class(vm), + obj.class() + )) + }) + } + + fn downcast_ref<'a>(zelf: &'a PyObjectRef, vm: &VirtualMachine) -> PyResult<&'a PyRef> { + zelf.downcast_ref::().ok_or_else(|| { + vm.new_type_error(format!( + "{} type is required, not {}", + Self::class(vm), + zelf.class() + )) + }) + } + + fn as_mapping(zelf: &PyRef, vm: &VirtualMachine) -> PyResult; + + fn length(zelf: PyObjectRef, _vm: &VirtualMachine) -> PyResult; + + fn subscript(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult; + + fn ass_subscript( + zelf: PyObjectRef, + needle: PyObjectRef, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()>; +} + #[pyimpl] pub trait Iterable: PyValue { #[pyslot] diff --git a/vm/src/stdlib/posix.rs b/vm/src/stdlib/posix.rs index 58f10618e1..ae6251a5a7 100644 --- a/vm/src/stdlib/posix.rs +++ b/vm/src/stdlib/posix.rs @@ -1049,8 +1049,24 @@ pub mod module { } #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))] - fn envp_from_dict(dict: PyDictRef, vm: &VirtualMachine) -> PyResult> { - dict.into_iter() + fn envp_from_dict( + env: crate::protocol::PyMapping, + vm: &VirtualMachine, + ) -> PyResult> { + let keys = env.keys(vm)?; + let values = env.values(vm)?; + + let keys = PyListRef::try_from_object(vm, keys) + .map_err(|_| vm.new_type_error("env.keys() is not a list".to_owned()))? + .borrow_vec() + .to_vec(); + let values = PyListRef::try_from_object(vm, values) + .map_err(|_| vm.new_type_error("env.values() is not a list".to_owned()))? + .borrow_vec() + .to_vec(); + + keys.into_iter() + .zip(values.into_iter()) .map(|(k, v)| { let k = PyPathLike::try_from_object(vm, k)?.into_bytes(); let v = PyPathLike::try_from_object(vm, v)?.into_bytes(); @@ -1085,7 +1101,7 @@ pub mod module { #[pyarg(positional)] args: crate::function::ArgIterable, #[pyarg(positional)] - env: crate::builtins::dict::PyMapping, + env: crate::protocol::PyMapping, #[pyarg(named, default)] file_actions: Option>, #[pyarg(named, default)] @@ -1198,7 +1214,7 @@ pub mod module { .map(|s| s.as_ptr() as _) .chain(std::iter::once(std::ptr::null_mut())) .collect(); - let mut env = envp_from_dict(self.env.into_dict(), vm)?; + let mut env = envp_from_dict(self.env, vm)?; let envp: Vec<*mut libc::c_char> = env .iter_mut() .map(|s| s.as_ptr() as _) diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 5ca3e20712..1ce91a8378 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -2,10 +2,11 @@ use super::os::errno_err; use crate::{ - builtins::{dict::PyMapping, PyDictRef, PyStrRef}, + builtins::{PyListRef, PyStrRef}, exceptions::IntoPyException, function::OptionalArg, - PyObjectRef, PyResult, PySequence, TryFromObject, VirtualMachine, + protocol::PyMapping, + ItemProtocol, PyObjectRef, PyResult, PySequence, TryFromObject, VirtualMachine, }; use std::ptr::{null, null_mut}; use winapi::shared::winerror; @@ -155,7 +156,7 @@ fn _winapi_CreateProcess( let mut env = args .env_mapping - .map(|m| getenvironment(m.into_dict(), vm)) + .map(|m| getenvironment(m, vm)) .transpose()?; let env = env.as_mut().map_or_else(null_mut, |v| v.as_mut_ptr()); @@ -216,9 +217,21 @@ fn _winapi_CreateProcess( )) } -fn getenvironment(env: PyDictRef, vm: &VirtualMachine) -> PyResult> { +fn getenvironment(env: PyMapping, vm: &VirtualMachine) -> PyResult> { + let keys = env.keys(vm)?; + let values = env.values(vm)?; + + let keys = PyListRef::try_from_object(vm, keys)?.borrow_vec().to_vec(); + let values = PyListRef::try_from_object(vm, values)? + .borrow_vec() + .to_vec(); + + if keys.len() != values.len() { + return Err(vm.new_runtime_error("environment changed size during iteration".to_owned())); + } + let mut out = widestring::WideString::new(); - for (k, v) in env { + for (k, v) in keys.into_iter().zip(values.into_iter()) { let k = PyStrRef::try_from_object(vm, k)?; let k = k.as_str(); let v = PyStrRef::try_from_object(vm, v)?; @@ -252,10 +265,11 @@ impl Drop for AttrList { fn getattributelist(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult> { >::try_from_object(vm, obj)? - .map(|d| { - let d = d.into_dict(); - let handlelist = d - .get_item_option("handle_list", vm)? + .map(|mapping| { + let handlelist = mapping + .into_object() + .get_item("handle_list", vm) + .ok() .and_then(|obj| { >>::try_from_object(vm, obj) .map(|s| match s { @@ -265,6 +279,7 @@ fn getattributelist(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult