From 54bcb085181ef15e4f7591be5c2adbea37b0e474 Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Wed, 6 Feb 2019 15:36:55 -0800 Subject: [PATCH] Add range.{__bool__, __contains__} - Also fix range.index for negative steps --- tests/snippets/builtin_range.py | 23 ++++++++++++ vm/src/obj/objrange.rs | 65 ++++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index 5e58ae069e..c822ce7e75 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -22,6 +22,7 @@ def assert_raises(expr, exc_type): assert range(10).index(6) == 6 assert range(4, 10).index(6) == 2 assert range(4, 10, 2).index(6) == 1 +assert range(10, 4, -2).index(8) == 1 # index raises value error on out of bounds assert_raises(lambda _: range(10).index(-1), ValueError) @@ -29,3 +30,25 @@ def assert_raises(expr, exc_type): # index raises value error if out of step assert_raises(lambda _: range(4, 10, 2).index(5), ValueError) + +# index raises value error if needle is not an int +assert_raises(lambda _: range(10).index('foo'), ValueError) + +# __bool__ +assert range(1).__bool__() +assert range(1, 2).__bool__() + +assert not range(0).__bool__() +assert not range(1, 1).__bool__() + +# __contains__ +assert range(10).__contains__(6) +assert range(4, 10).__contains__(6) +assert range(4, 10, 2).__contains__(6) +assert range(10, 4, -2).__contains__(10) +assert range(10, 4, -2).__contains__(8) + +assert not range(10).__contains__(-1) +assert not range(10, 4, -2).__contains__(9) +assert not range(10, 4, -2).__contains__(4) +assert not range(10).__contains__('foo') diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 605dabf849..cd897ff347 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -4,7 +4,7 @@ use super::super::pyobject::{ use super::super::vm::VirtualMachine; use super::objint; use super::objtype; -use num_bigint::BigInt; +use num_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; @@ -27,16 +27,29 @@ impl RangeType { } #[inline] - pub fn index_of(&self, value: &BigInt) -> Option { - if value < &self.start || value >= &self.end { - return None; + fn offset(&self, value: &BigInt) -> Option { + match self.step.sign() { + Sign::Plus if value >= &self.start && value < &self.end => Some(value - &self.start), + Sign::Minus if value <= &self.start && value > &self.end => Some(&self.start - value), + _ => None, } + } - let offset = value - &self.start; - if offset.is_multiple_of(&self.step) { - Some(offset / &self.step) - } else { - None + #[inline] + pub fn contains(&self, value: &BigInt) -> bool { + match self.offset(value) { + Some(ref offset) => offset.is_multiple_of(&self.step), + None => false, + } + } + + #[inline] + pub fn index_of(&self, value: &BigInt) -> Option { + match self.offset(value) { + Some(ref offset) if offset.is_multiple_of(&self.step) => { + Some((offset / &self.step).abs()) + } + Some(_) | None => None, } } @@ -75,6 +88,12 @@ pub fn init(context: &PyContext) { "__getitem__", context.new_rustfunc(range_getitem), ); + context.set_attr(&range_type, "__bool__", context.new_rustfunc(range_bool)); + context.set_attr( + &range_type, + "__contains__", + context.new_rustfunc(range_contains), + ); context.set_attr(&range_type, "index", context.new_rustfunc(range_index)); } @@ -205,6 +224,34 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); + + let len = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.len(), + _ => unreachable!(), + }; + + Ok(vm.ctx.new_bool(len > 0)) +} + +fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, Some(vm.ctx.range_type())), (needle, None)] + ); + + if let PyObjectPayload::Range { ref range } = zelf.borrow().payload { + Ok(vm.ctx.new_bool(match needle.borrow().payload { + PyObjectPayload::Integer { ref value } => range.contains(value), + _ => false, + })) + } else { + unreachable!() + } +} + fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm,