Skip to content

Commit a4b9925

Browse files
Merge pull request #382 from OddCoincidence/range-bool-contains
Add range.{__bool__, __contains__}
2 parents 3dc2ab9 + 54bcb08 commit a4b9925

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

tests/snippets/builtin_range.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,33 @@ def assert_raises(expr, exc_type):
2222
assert range(10).index(6) == 6
2323
assert range(4, 10).index(6) == 2
2424
assert range(4, 10, 2).index(6) == 1
25+
assert range(10, 4, -2).index(8) == 1
2526

2627
# index raises value error on out of bounds
2728
assert_raises(lambda _: range(10).index(-1), ValueError)
2829
assert_raises(lambda _: range(10).index(10), ValueError)
2930

3031
# index raises value error if out of step
3132
assert_raises(lambda _: range(4, 10, 2).index(5), ValueError)
33+
34+
# index raises value error if needle is not an int
35+
assert_raises(lambda _: range(10).index('foo'), ValueError)
36+
37+
# __bool__
38+
assert range(1).__bool__()
39+
assert range(1, 2).__bool__()
40+
41+
assert not range(0).__bool__()
42+
assert not range(1, 1).__bool__()
43+
44+
# __contains__
45+
assert range(10).__contains__(6)
46+
assert range(4, 10).__contains__(6)
47+
assert range(4, 10, 2).__contains__(6)
48+
assert range(10, 4, -2).__contains__(10)
49+
assert range(10, 4, -2).__contains__(8)
50+
51+
assert not range(10).__contains__(-1)
52+
assert not range(10, 4, -2).__contains__(9)
53+
assert not range(10, 4, -2).__contains__(4)
54+
assert not range(10).__contains__('foo')

vm/src/obj/objrange.rs

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use super::super::pyobject::{
44
use super::super::vm::VirtualMachine;
55
use super::objint;
66
use super::objtype;
7-
use num_bigint::BigInt;
7+
use num_bigint::{BigInt, Sign};
88
use num_integer::Integer;
99
use num_traits::{One, Signed, ToPrimitive, Zero};
1010

@@ -27,16 +27,29 @@ impl RangeType {
2727
}
2828

2929
#[inline]
30-
pub fn index_of(&self, value: &BigInt) -> Option<BigInt> {
31-
if value < &self.start || value >= &self.end {
32-
return None;
30+
fn offset(&self, value: &BigInt) -> Option<BigInt> {
31+
match self.step.sign() {
32+
Sign::Plus if value >= &self.start && value < &self.end => Some(value - &self.start),
33+
Sign::Minus if value <= &self.start && value > &self.end => Some(&self.start - value),
34+
_ => None,
3335
}
36+
}
3437

35-
let offset = value - &self.start;
36-
if offset.is_multiple_of(&self.step) {
37-
Some(offset / &self.step)
38-
} else {
39-
None
38+
#[inline]
39+
pub fn contains(&self, value: &BigInt) -> bool {
40+
match self.offset(value) {
41+
Some(ref offset) => offset.is_multiple_of(&self.step),
42+
None => false,
43+
}
44+
}
45+
46+
#[inline]
47+
pub fn index_of(&self, value: &BigInt) -> Option<BigInt> {
48+
match self.offset(value) {
49+
Some(ref offset) if offset.is_multiple_of(&self.step) => {
50+
Some((offset / &self.step).abs())
51+
}
52+
Some(_) | None => None,
4053
}
4154
}
4255

@@ -75,6 +88,12 @@ pub fn init(context: &PyContext) {
7588
"__getitem__",
7689
context.new_rustfunc(range_getitem),
7790
);
91+
context.set_attr(&range_type, "__bool__", context.new_rustfunc(range_bool));
92+
context.set_attr(
93+
&range_type,
94+
"__contains__",
95+
context.new_rustfunc(range_contains),
96+
);
7897
context.set_attr(&range_type, "index", context.new_rustfunc(range_index));
7998
}
8099

@@ -205,6 +224,34 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
205224
}
206225
}
207226

227+
fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
228+
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]);
229+
230+
let len = match zelf.borrow().payload {
231+
PyObjectPayload::Range { ref range } => range.len(),
232+
_ => unreachable!(),
233+
};
234+
235+
Ok(vm.ctx.new_bool(len > 0))
236+
}
237+
238+
fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
239+
arg_check!(
240+
vm,
241+
args,
242+
required = [(zelf, Some(vm.ctx.range_type())), (needle, None)]
243+
);
244+
245+
if let PyObjectPayload::Range { ref range } = zelf.borrow().payload {
246+
Ok(vm.ctx.new_bool(match needle.borrow().payload {
247+
PyObjectPayload::Integer { ref value } => range.contains(value),
248+
_ => false,
249+
}))
250+
} else {
251+
unreachable!()
252+
}
253+
}
254+
208255
fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
209256
arg_check!(
210257
vm,

0 commit comments

Comments
 (0)