From a5558e0e32971c8b6c2a5754250e9a81152e31fa Mon Sep 17 00:00:00 2001 From: Joey Date: Sat, 23 Mar 2019 15:05:12 -0700 Subject: [PATCH 1/4] Introduce Either extractor and convert range.__getitem__ --- vm/src/obj/objrange.rs | 115 ++++++++++++++++++----------------------- vm/src/obj/objslice.rs | 4 +- vm/src/pyobject.rs | 44 ++++++++++++++++ 3 files changed, 98 insertions(+), 65 deletions(-) diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index db4ded48dd..8d526227a3 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -7,14 +7,13 @@ use num_traits::{One, Signed, ToPrimitive, Zero}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + Either2, PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::vm::VirtualMachine; use super::objint::{self, PyInt, PyIntRef}; -use super::objslice::PySlice; -use super::objtype; -use super::objtype::PyClassRef; +use super::objslice::PySliceRef; +use super::objtype::{self, PyClassRef}; #[derive(Debug, Clone)] pub struct PyRange { @@ -169,7 +168,7 @@ pub fn init(context: &PyContext) { "__bool__" => context.new_rustfunc(range_bool), "__contains__" => context.new_rustfunc(range_contains), "__doc__" => context.new_str(range_doc.to_string()), - "__getitem__" => context.new_rustfunc(range_getitem), + "__getitem__" => context.new_rustfunc(PyRangeRef::getitem), "__iter__" => context.new_rustfunc(range_iter), "__len__" => context.new_rustfunc(range_len), "__new__" => context.new_rustfunc(range_new), @@ -212,6 +211,53 @@ impl PyRangeRef { } .into_ref_with_type(vm, cls) } + + fn getitem(self, subscript: Either2, vm: &VirtualMachine) -> PyResult { + match subscript { + Either2::A(index) => { + if let Some(value) = self.get(index.value.clone()) { + Ok(PyInt::new(value).into_ref(vm).into_object()) + } else { + Err(vm.new_index_error("range object index out of range".to_string())) + } + } + Either2::B(slice) => { + let new_start = if let Some(int) = slice.start.clone() { + if let Some(i) = self.get(int) { + i + } else { + self.start.clone() + } + } else { + self.start.clone() + }; + + let new_end = if let Some(int) = slice.stop.clone() { + if let Some(i) = self.get(int) { + i + } else { + self.stop.clone() + } + } else { + self.stop.clone() + }; + + let new_step = if let Some(int) = slice.step.clone() { + int * self.step.clone() + } else { + self.step.clone() + }; + + Ok(PyRange { + start: new_start, + stop: new_end, + step: new_step, + } + .into_ref(vm) + .into_object()) + } + } + } } fn range_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -252,65 +298,6 @@ fn range_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn range_getitem(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.range_type())), (subscript, None)] - ); - - let range = get_value(zelf); - - if let Some(i) = subscript.payload::() { - if let Some(int) = range.get(i.value.clone()) { - Ok(vm.ctx.new_int(int)) - } else { - Err(vm.new_index_error("range object index out of range".to_string())) - } - } else if let Some(PySlice { - ref start, - ref stop, - ref step, - }) = subscript.payload() - { - let new_start = if let Some(int) = start { - if let Some(i) = range.get(int) { - i - } else { - range.start.clone() - } - } else { - range.start.clone() - }; - - let new_end = if let Some(int) = stop { - if let Some(i) = range.get(int) { - i - } else { - range.stop - } - } else { - range.stop - }; - - let new_step = if let Some(int) = step { - int * range.step - } else { - range.step - }; - - Ok(PyRange { - start: new_start, - stop: new_end, - step: new_step, - } - .into_ref(vm) - .into_object()) - } else { - Err(vm.new_type_error("range indices must be integer or slice".to_string())) - } -} - fn range_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index 4a7124fdfd..d171b2527b 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -1,7 +1,7 @@ use num_bigint::BigInt; use crate::function::PyFuncArgs; -use crate::pyobject::{PyContext, PyObject, PyObjectRef, PyResult, PyValue, TypeProtocol}; +use crate::pyobject::{PyContext, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; use crate::vm::VirtualMachine; use super::objint; @@ -21,6 +21,8 @@ impl PyValue for PySlice { } } +pub type PySliceRef = PyRef; + fn slice_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { no_kwargs!(vm, args); let (cls, start, stop, step): ( diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 41a1ca6f34..a90a4ebc08 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1218,6 +1218,50 @@ impl PyObjectPayload for T { } } +pub enum Either2 { + A(A), + B(B), +} + +/// This allows a builtin method to accept arguments that may be one of two +/// types, raising a `TypeError` if it is neither. +/// +/// # Example +/// +/// ``` +/// fn do_something(arg: Either2, vm: &VirtualMachine) { +/// match arg { +/// Either2::A(int)=> { +/// // do something with int +/// } +/// Either2::B(string) => { +/// // do something with string +/// } +/// } +/// } +/// ``` +impl TryFromObject for Either2, PyRef> +where + A: PyValue, + B: PyValue, +{ + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + // TODO: downcast could probably be reworked a bit to make these clones not necessary + obj.clone() + .downcast::() + .map(Either2::A) + .or_else(|| obj.clone().downcast::().map(Either2::B)) + .ok_or_else(|| { + vm.new_type_error(format!( + "must be {} or {}, not {}", + A::class(vm), + B::class(vm), + obj.type_pyref() + )) + }) + } +} + #[cfg(test)] mod tests { use super::*; From b6f1ecdb4be116d6242ed05541725adedf9a175c Mon Sep 17 00:00:00 2001 From: Joey Date: Sat, 23 Mar 2019 15:49:31 -0700 Subject: [PATCH 2/4] Fix example --- vm/src/pyobject.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index a90a4ebc08..40a2d56e25 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1229,6 +1229,10 @@ pub enum Either2 { /// # Example /// /// ``` +/// use rustpython_vm::VirtualMachine; +/// use rustpython_vm::obj::{objstr::PyStringRef, objint::PyIntRef}; +/// use rustpython_vm::pyobject::Either2; +/// /// fn do_something(arg: Either2, vm: &VirtualMachine) { /// match arg { /// Either2::A(int)=> { From 3c15d892c5b99f903c5389dc1e5beb6e76b488c3 Mon Sep 17 00:00:00 2001 From: Joey Date: Sat, 23 Mar 2019 15:57:06 -0700 Subject: [PATCH 3/4] Avoid some cloning --- vm/src/pyobject.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index aec210e7a8..befa798926 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -700,16 +700,23 @@ where } impl PyObject { - pub fn downcast(self: Rc) -> Option> { + /// Attempt to downcast this reference to a subclass. + /// + /// If the downcast fails, the original ref is returned in as `Err` so + /// another downcast can be attempted without unnecessary cloning. + /// + /// Note: The returned `Result` is _not_ a `PyResult`, even though the + /// types are compatible. + pub fn downcast(self: Rc) -> Result, PyObjectRef> { if self.payload_is::() { - Some({ + Ok({ PyRef { obj: self, _payload: PhantomData, } }) } else { - None + Err(self) } } } @@ -1228,12 +1235,10 @@ where B: PyValue, { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - // TODO: downcast could probably be reworked a bit to make these clones not necessary - obj.clone() - .downcast::() + obj.downcast::() .map(Either2::A) - .or_else(|| obj.clone().downcast::().map(Either2::B)) - .ok_or_else(|| { + .or_else(|obj| obj.clone().downcast::().map(Either2::B)) + .map_err(|obj| { vm.new_type_error(format!( "must be {} or {}, not {}", A::class(vm), From 84d47d21cf74aec990dbb5e4ddc4fa2f449d1298 Mon Sep 17 00:00:00 2001 From: Joey Date: Sat, 23 Mar 2019 17:51:12 -0700 Subject: [PATCH 4/4] Rename to just Either --- vm/src/obj/objrange.rs | 8 ++++---- vm/src/pyobject.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 8d526227a3..468490eeaa 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -7,7 +7,7 @@ use num_traits::{One, Signed, ToPrimitive, Zero}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - Either2, PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + Either, PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -212,16 +212,16 @@ impl PyRangeRef { .into_ref_with_type(vm, cls) } - fn getitem(self, subscript: Either2, vm: &VirtualMachine) -> PyResult { + fn getitem(self, subscript: Either, vm: &VirtualMachine) -> PyResult { match subscript { - Either2::A(index) => { + Either::A(index) => { if let Some(value) = self.get(index.value.clone()) { Ok(PyInt::new(value).into_ref(vm).into_object()) } else { Err(vm.new_index_error("range object index out of range".to_string())) } } - Either2::B(slice) => { + Either::B(slice) => { let new_start = if let Some(int) = slice.start.clone() { if let Some(i) = self.get(int) { i diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index befa798926..19a1eee9f7 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1203,7 +1203,7 @@ impl PyObjectPayload for T { } } -pub enum Either2 { +pub enum Either { A(A), B(B), } @@ -1216,28 +1216,28 @@ pub enum Either2 { /// ``` /// use rustpython_vm::VirtualMachine; /// use rustpython_vm::obj::{objstr::PyStringRef, objint::PyIntRef}; -/// use rustpython_vm::pyobject::Either2; +/// use rustpython_vm::pyobject::Either; /// -/// fn do_something(arg: Either2, vm: &VirtualMachine) { +/// fn do_something(arg: Either, vm: &VirtualMachine) { /// match arg { -/// Either2::A(int)=> { +/// Either::A(int)=> { /// // do something with int /// } -/// Either2::B(string) => { +/// Either::B(string) => { /// // do something with string /// } /// } /// } /// ``` -impl TryFromObject for Either2, PyRef> +impl TryFromObject for Either, PyRef> where A: PyValue, B: PyValue, { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { obj.downcast::() - .map(Either2::A) - .or_else(|obj| obj.clone().downcast::().map(Either2::B)) + .map(Either::A) + .or_else(|obj| obj.clone().downcast::().map(Either::B)) .map_err(|obj| { vm.new_type_error(format!( "must be {} or {}, not {}",