diff --git a/tests/snippets/builtin_slice.py b/tests/snippets/builtin_slice.py index b12ad8021a..46592769ef 100644 --- a/tests/snippets/builtin_slice.py +++ b/tests/snippets/builtin_slice.py @@ -11,6 +11,7 @@ b = [1, 2] assert b[:] == [1, 2] +assert b[slice(None)] == [1, 2] assert b[: 2 ** 100] == [1, 2] assert b[-2 ** 100 :] == [1, 2] assert b[2 ** 100 :] == [] @@ -60,6 +61,12 @@ assert slice_c.stop == 5 assert slice_c.step == 2 +a = object() +slice_d = slice(a, "v", 1.0) +assert slice_d.start is a +assert slice_d.stop == "v" +assert slice_d.step == 1.0 + class SubScript(object): def __getitem__(self, item): @@ -74,6 +81,18 @@ def __setitem__(self, key, value): ss[:1] = 1 +class CustomIndex: + def __init__(self, x): + self.x = x + + def __index__(self): + return self.x + + +assert c[CustomIndex(1):CustomIndex(3)] == [1, 2] +assert d[CustomIndex(1):CustomIndex(3)] == "23" + + def test_all_slices(): """ test all possible slices except big number diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 2bf8215c1c..6685587133 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -2,8 +2,6 @@ use std::cell::RefCell; use std::fmt; use std::rc::Rc; -use num_bigint::BigInt; - use rustpython_parser::ast; use crate::builtins; @@ -12,7 +10,6 @@ use crate::function::PyFuncArgs; use crate::obj::objbool; use crate::obj::objcode::PyCodeRef; use crate::obj::objdict::{PyDict, PyDictRef}; -use crate::obj::objint::PyInt; use crate::obj::objiter; use crate::obj::objlist; use crate::obj::objslice::PySlice; @@ -435,26 +432,21 @@ impl Frame { } bytecode::Instruction::BuildSlice { size } => { assert!(*size == 2 || *size == 3); - let elements = self.pop_multiple(*size); - - let mut out: Vec> = elements - .into_iter() - .map(|x| { - if x.is(&vm.ctx.none()) { - None - } else if let Some(i) = x.payload::() { - Some(i.as_bigint().clone()) - } else { - panic!("Expect Int or None as BUILD_SLICE arguments") - } - }) - .collect(); - let start = out[0].take(); - let stop = out[1].take(); - let step = if out.len() == 3 { out[2].take() } else { None }; + let step = if *size == 3 { + Some(self.pop_value()) + } else { + None + }; + let stop = self.pop_value(); + let start = self.pop_value(); - let obj = PySlice { start, stop, step }.into_ref(vm); + let obj = PySlice { + start: Some(start), + stop, + step, + } + .into_ref(vm); self.push_value(obj.into_object()); Ok(None) } diff --git a/vm/src/function.rs b/vm/src/function.rs index 9ed237a508..a72ead8f4e 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::mem; use std::ops::RangeInclusive; use crate::obj::objtype::{isinstance, PyClassRef}; @@ -48,6 +49,12 @@ impl From<(&Args, &KwArgs)> for PyFuncArgs { } } +impl FromArgs for PyFuncArgs { + fn from_args(_vm: &VirtualMachine, args: &mut PyFuncArgs) -> Result { + Ok(mem::replace(args, Default::default())) + } +} + impl PyFuncArgs { pub fn new(mut args: Vec, kwarg_names: Vec) -> PyFuncArgs { let mut kwargs = vec![]; diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index a21bd9190c..bfc763d8cd 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -210,12 +210,12 @@ impl PyListRef { } fn setslice(self, slice: PySliceRef, sec: PyIterable, vm: &VirtualMachine) -> PyResult { - let step = slice.step.clone().unwrap_or_else(BigInt::one); + let step = slice.step_index(vm)?.unwrap_or_else(BigInt::one); if step.is_zero() { Err(vm.new_value_error("slice step cannot be zero".to_string())) } else if step.is_positive() { - let range = self.get_slice_range(&slice.start, &slice.stop); + let range = self.get_slice_range(&slice.start_index(vm)?, &slice.stop_index(vm)?); if range.start < range.end { match step.to_i32() { Some(1) => self._set_slice(range, sec, vm), @@ -237,14 +237,14 @@ impl PyListRef { } else { // calculate the range for the reverse slice, first the bounds needs to be made // exclusive around stop, the lower number - let start = &slice.start.as_ref().map(|x| { + let start = &slice.start_index(vm)?.as_ref().map(|x| { if *x == (-1).to_bigint().unwrap() { self.get_len() + BigInt::one() //.to_bigint().unwrap() } else { x + 1 } }); - let stop = &slice.stop.as_ref().map(|x| { + let stop = &slice.stop_index(vm)?.as_ref().map(|x| { if *x == (-1).to_bigint().unwrap() { self.get_len().to_bigint().unwrap() } else { @@ -552,9 +552,9 @@ impl PyListRef { } fn delslice(self, slice: PySliceRef, vm: &VirtualMachine) -> PyResult { - let start = &slice.start; - let stop = &slice.stop; - let step = slice.step.clone().unwrap_or_else(BigInt::one); + let start = slice.start_index(vm)?; + let stop = slice.stop_index(vm)?; + let step = slice.step_index(vm)?.unwrap_or_else(BigInt::one); if step.is_zero() { Err(vm.new_value_error("slice step cannot be zero".to_string())) diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 00b4e0e2ed..1a21a49ad3 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -273,7 +273,7 @@ impl PyRange { } } RangeIndex::Slice(slice) => { - let new_start = if let Some(int) = slice.start.as_ref() { + let new_start = if let Some(int) = slice.start_index(vm)? { if let Some(i) = self.get(int) { PyInt::new(i).into_ref(vm) } else { @@ -283,7 +283,7 @@ impl PyRange { self.start.clone() }; - let new_end = if let Some(int) = slice.stop.as_ref() { + let new_end = if let Some(int) = slice.stop_index(vm)? { if let Some(i) = self.get(int) { PyInt::new(i).into_ref(vm) } else { @@ -293,7 +293,7 @@ impl PyRange { self.stop.clone() }; - let new_step = if let Some(int) = slice.step.as_ref() { + let new_step = if let Some(int) = slice.step_index(vm)? { PyInt::new(int * self.step.as_bigint()).into_ref(vm) } else { self.step.clone() diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index 32c723a26f..5594ac8586 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -65,14 +65,15 @@ pub trait PySliceableSequence { where Self: Sized, { - // TODO: we could potentially avoid this copy and use slice - match slice.payload() { - Some(PySlice { start, stop, step }) => { - let step = step.clone().unwrap_or_else(BigInt::one); + match slice.clone().downcast::() { + Ok(slice) => { + let start = slice.start_index(vm)?; + let stop = slice.stop_index(vm)?; + let step = slice.step_index(vm)?.unwrap_or_else(BigInt::one); if step.is_zero() { Err(vm.new_value_error("slice step cannot be zero".to_string())) } else if step.is_positive() { - let range = self.get_slice_range(start, stop); + let range = self.get_slice_range(&start, &stop); if range.start < range.end { #[allow(clippy::range_plus_one)] match step.to_i32() { diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index 60d911a921..82e4cef8cd 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -1,18 +1,16 @@ -use num_bigint::BigInt; - -use crate::function::PyFuncArgs; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::function::{OptionalArg, PyFuncArgs}; +use crate::pyobject::{IdProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; use crate::vm::VirtualMachine; -use super::objint; -use crate::obj::objtype::PyClassRef; +use crate::obj::objint::PyInt; +use crate::obj::objtype::{class_has_attr, PyClassRef}; +use num_bigint::BigInt; #[derive(Debug)] pub struct PySlice { - // TODO: should be private - pub start: Option, - pub stop: Option, - pub step: Option, + pub start: Option, + pub stop: PyObjectRef, + pub step: Option, } impl PyValue for PySlice { @@ -23,52 +21,35 @@ impl PyValue for PySlice { pub type PySliceRef = PyRef; -fn slice_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - no_kwargs!(vm, args); - let (cls, start, stop, step): ( - &PyObjectRef, - Option<&PyObjectRef>, - Option<&PyObjectRef>, - Option<&PyObjectRef>, - ) = match args.args.len() { - 0 | 1 => Err(vm.new_type_error("slice() must have at least one arguments.".to_owned())), - 2 => { - arg_check!( - vm, - args, - required = [ - (cls, Some(vm.ctx.type_type())), - (stop, Some(vm.ctx.int_type())) - ] - ); - Ok((cls, None, Some(stop), None)) +fn slice_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + let slice: PySlice = match args.args.len() { + 0 => { + return Err(vm.new_type_error("slice() must have at least one arguments.".to_owned())); + } + 1 => { + let stop = args.bind(vm)?; + PySlice { + start: None, + stop, + step: None, + } } _ => { - arg_check!( - vm, - args, - required = [ - (cls, Some(vm.ctx.type_type())), - (start, Some(vm.ctx.int_type())), - (stop, Some(vm.ctx.int_type())) - ], - optional = [(step, Some(vm.ctx.int_type()))] - ); - Ok((cls, Some(start), Some(stop), step)) + let (start, stop, step): (PyObjectRef, PyObjectRef, OptionalArg) = + args.bind(vm)?; + PySlice { + start: Some(start), + stop, + step: step.into_option(), + } } - }?; - PySlice { - start: start.map(|x| objint::get_value(x).clone()), - stop: stop.map(|x| objint::get_value(x).clone()), - step: step.map(|x| objint::get_value(x).clone()), - } - .into_ref_with_type(vm, cls.clone().downcast().unwrap()) - .map(PyRef::into_object) + }; + slice.into_ref_with_type(vm, cls) } -fn get_property_value(vm: &VirtualMachine, value: &Option) -> PyObjectRef { +fn get_property_value(vm: &VirtualMachine, value: &Option) -> PyObjectRef { if let Some(value) = value { - vm.ctx.new_int(value.clone()) + value.clone() } else { vm.get_none() } @@ -79,13 +60,57 @@ impl PySliceRef { get_property_value(vm, &self.start) } - fn stop(self, vm: &VirtualMachine) -> PyObjectRef { - get_property_value(vm, &self.stop) + fn stop(self, _vm: &VirtualMachine) -> PyObjectRef { + self.stop.clone() } fn step(self, vm: &VirtualMachine) -> PyObjectRef { get_property_value(vm, &self.step) } + + pub fn start_index(&self, vm: &VirtualMachine) -> PyResult> { + if let Some(obj) = &self.start { + to_index_value(vm, obj) + } else { + Ok(None) + } + } + + pub fn stop_index(&self, vm: &VirtualMachine) -> PyResult> { + to_index_value(vm, &self.stop) + } + + pub fn step_index(&self, vm: &VirtualMachine) -> PyResult> { + if let Some(obj) = &self.step { + to_index_value(vm, obj) + } else { + Ok(None) + } + } +} + +fn to_index_value(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + if obj.is(&vm.ctx.none) { + return Ok(None); + } + + if let Some(val) = obj.payload::() { + Ok(Some(val.as_bigint().clone())) + } else { + let cls = obj.class(); + if class_has_attr(&cls, "__index__") { + let index_result = vm.call_method(obj, "__index__", vec![])?; + if let Some(val) = index_result.payload::() { + Ok(Some(val.as_bigint().clone())) + } else { + Err(vm.new_type_error("__index__ method returned non integer".to_string())) + } + } else { + Err(vm.new_type_error( + "slice indices must be integers or None or have an __index__ method".to_string(), + )) + } + } } pub fn init(context: &PyContext) {