diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index db4ded48dd..468490eeaa 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -7,14 +7,13 @@ use num_traits::{One, Signed, ToPrimitive, Zero}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + Either, PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::vm::VirtualMachine; use super::objint::{self, PyInt, PyIntRef}; -use super::objslice::PySlice; -use super::objtype; -use super::objtype::PyClassRef; +use super::objslice::PySliceRef; +use super::objtype::{self, PyClassRef}; #[derive(Debug, Clone)] pub struct PyRange { @@ -169,7 +168,7 @@ pub fn init(context: &PyContext) { "__bool__" => context.new_rustfunc(range_bool), "__contains__" => context.new_rustfunc(range_contains), "__doc__" => context.new_str(range_doc.to_string()), - "__getitem__" => context.new_rustfunc(range_getitem), + "__getitem__" => context.new_rustfunc(PyRangeRef::getitem), "__iter__" => context.new_rustfunc(range_iter), "__len__" => context.new_rustfunc(range_len), "__new__" => context.new_rustfunc(range_new), @@ -212,6 +211,53 @@ impl PyRangeRef { } .into_ref_with_type(vm, cls) } + + fn getitem(self, subscript: Either, vm: &VirtualMachine) -> PyResult { + match subscript { + Either::A(index) => { + if let Some(value) = self.get(index.value.clone()) { + Ok(PyInt::new(value).into_ref(vm).into_object()) + } else { + Err(vm.new_index_error("range object index out of range".to_string())) + } + } + Either::B(slice) => { + let new_start = if let Some(int) = slice.start.clone() { + if let Some(i) = self.get(int) { + i + } else { + self.start.clone() + } + } else { + self.start.clone() + }; + + let new_end = if let Some(int) = slice.stop.clone() { + if let Some(i) = self.get(int) { + i + } else { + self.stop.clone() + } + } else { + self.stop.clone() + }; + + let new_step = if let Some(int) = slice.step.clone() { + int * self.step.clone() + } else { + self.step.clone() + }; + + Ok(PyRange { + start: new_start, + stop: new_end, + step: new_step, + } + .into_ref(vm) + .into_object()) + } + } + } } fn range_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -252,65 +298,6 @@ fn range_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn range_getitem(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.range_type())), (subscript, None)] - ); - - let range = get_value(zelf); - - if let Some(i) = subscript.payload::() { - if let Some(int) = range.get(i.value.clone()) { - Ok(vm.ctx.new_int(int)) - } else { - Err(vm.new_index_error("range object index out of range".to_string())) - } - } else if let Some(PySlice { - ref start, - ref stop, - ref step, - }) = subscript.payload() - { - let new_start = if let Some(int) = start { - if let Some(i) = range.get(int) { - i - } else { - range.start.clone() - } - } else { - range.start.clone() - }; - - let new_end = if let Some(int) = stop { - if let Some(i) = range.get(int) { - i - } else { - range.stop - } - } else { - range.stop - }; - - let new_step = if let Some(int) = step { - int * range.step - } else { - range.step - }; - - Ok(PyRange { - start: new_start, - stop: new_end, - step: new_step, - } - .into_ref(vm) - .into_object()) - } else { - Err(vm.new_type_error("range indices must be integer or slice".to_string())) - } -} - fn range_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index d68ae2255f..3d955b3f2c 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -1,7 +1,7 @@ use num_bigint::BigInt; use crate::function::PyFuncArgs; -use crate::pyobject::{PyContext, PyObjectRef, PyResult, PyValue, TypeProtocol}; +use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; use crate::vm::VirtualMachine; use super::objint; @@ -21,6 +21,8 @@ impl PyValue for PySlice { } } +pub type PySliceRef = PyRef; + fn slice_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { no_kwargs!(vm, args); let (cls, start, stop, step): ( diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 80eba36da3..19a1eee9f7 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -700,16 +700,23 @@ where } impl PyObject { - pub fn downcast(self: Rc) -> Option> { + /// Attempt to downcast this reference to a subclass. + /// + /// If the downcast fails, the original ref is returned in as `Err` so + /// another downcast can be attempted without unnecessary cloning. + /// + /// Note: The returned `Result` is _not_ a `PyResult`, even though the + /// types are compatible. + pub fn downcast(self: Rc) -> Result, PyObjectRef> { if self.payload_is::() { - Some({ + Ok({ PyRef { obj: self, _payload: PhantomData, } }) } else { - None + Err(self) } } } @@ -1196,6 +1203,52 @@ impl PyObjectPayload for T { } } +pub enum Either { + A(A), + B(B), +} + +/// This allows a builtin method to accept arguments that may be one of two +/// types, raising a `TypeError` if it is neither. +/// +/// # Example +/// +/// ``` +/// use rustpython_vm::VirtualMachine; +/// use rustpython_vm::obj::{objstr::PyStringRef, objint::PyIntRef}; +/// use rustpython_vm::pyobject::Either; +/// +/// fn do_something(arg: Either, vm: &VirtualMachine) { +/// match arg { +/// Either::A(int)=> { +/// // do something with int +/// } +/// Either::B(string) => { +/// // do something with string +/// } +/// } +/// } +/// ``` +impl TryFromObject for Either, PyRef> +where + A: PyValue, + B: PyValue, +{ + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + obj.downcast::() + .map(Either::A) + .or_else(|obj| obj.clone().downcast::().map(Either::B)) + .map_err(|obj| { + vm.new_type_error(format!( + "must be {} or {}, not {}", + A::class(vm), + B::class(vm), + obj.type_pyref() + )) + }) + } +} + #[cfg(test)] mod tests { use super::*;