Skip to content

Commit 243a130

Browse files
committed
impl number protocol for PyInt
1 parent 68fa027 commit 243a130

File tree

3 files changed

+110
-44
lines changed

3 files changed

+110
-44
lines changed

vm/src/builtins/int.rs

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
common::hash,
55
format::FormatSpec,
66
function::{ArgIntoBool, IntoPyObject, IntoPyResult, OptionalArg, OptionalOption},
7-
protocol::PyNumberMethods,
7+
protocol::{PyNumber, PyNumberMethods},
88
try_value_from_borrowed_object,
99
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp},
1010
IdProtocol, PyArithmeticValue, PyClassImpl, PyComparisonValue, PyContext, PyObject,
@@ -745,47 +745,104 @@ impl AsNumber for PyInt {
745745
}
746746

747747
impl PyInt {
748+
fn number_protocol_binop<F>(
749+
number: &PyNumber,
750+
other: &PyObject,
751+
op: &str,
752+
f: F,
753+
vm: &VirtualMachine,
754+
) -> PyResult
755+
where
756+
F: FnOnce(&BigInt, &BigInt) -> BigInt,
757+
{
758+
let (a, b) = Self::downcast_or_binop_error(number, other, op, vm)?;
759+
let ret = f(&a.value, &b.value);
760+
Ok(vm.ctx.new_int(ret).into())
761+
}
762+
763+
fn number_protocol_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef {
764+
if let Some(zelf) = number.obj.downcast_ref_if_exact::<Self>(vm) {
765+
zelf.to_owned()
766+
} else {
767+
let zelf = Self::number_downcast(number);
768+
vm.ctx.new_int(zelf.value)
769+
}
770+
}
771+
748772
const NUMBER_METHODS: PyNumberMethods = PyNumberMethods {
749773
add: Some(|number, other, vm| {
750-
let (a, b) = Self::downcast_or_binop_error(number, other, "+", vm)?;
751-
let ret = a.value + b.value;
752-
Ok(vm.ctx.new_int(ret).into())
774+
Self::number_protocol_binop(number, other, "+", |a, b| a + b, vm)
775+
}),
776+
subtract: Some(|number, other, vm| {
777+
Self::number_protocol_binop(number, other, "-", |a, b| a - b, vm)
778+
}),
779+
multiply: Some(|number, other, vm| {
780+
Self::number_protocol_binop(number, other, "*", |a, b| a * b, vm)
781+
}),
782+
remainder: Some(|number, other, vm| {
783+
let (a, b) = Self::downcast_or_binop_error(number, other, "%", vm)?;
784+
inner_mod(&a.value, &b.value, vm)
785+
}),
786+
divmod: Some(|number, other, vm| {
787+
let (a, b) = Self::downcast_or_binop_error(number, other, "divmod()", vm)?;
788+
inner_divmod(&a.value, &b.value, vm)
789+
}),
790+
power: Some(|number, other, vm| {
791+
let (a, b) = Self::downcast_or_binop_error(number, other, "** or pow()", vm)?;
792+
inner_pow(&a.value, &b.value, vm)
793+
}),
794+
negative: Some(|number, vm| {
795+
let zelf = Self::number_downcast(number);
796+
Ok(vm.ctx.new_int(-zelf.value).into())
797+
}),
798+
positive: Some(|number, vm| Ok(Self::number_protocol_int(number, vm).into())),
799+
absolute: Some(|number, vm| {
800+
let zelf = Self::number_downcast(number);
801+
Ok(vm.ctx.new_int(zelf.value.abs()).into())
802+
}),
803+
boolean: Some(|number, vm| {
804+
let zelf = Self::number_downcast(number);
805+
Ok(zelf.value.is_zero())
806+
}),
807+
invert: Some(|number, vm| {
808+
let zelf = Self::number_downcast(number);
809+
Ok(vm.ctx.new_int(!zelf.value).into())
810+
}),
811+
lshift: Some(|number, other, vm| {
812+
let (a, b) = Self::downcast_or_binop_error(number, other, "<<", vm)?;
813+
inner_shift(&a.value, &b.value, |a, b| a << b, vm)
814+
}),
815+
rshift: Some(|number, other, vm| {
816+
let (a, b) = Self::downcast_or_binop_error(number, other, ">>", vm)?;
817+
inner_shift(&a.value, &b.value, |a, b| a >> b, vm)
818+
}),
819+
and: Some(|number, other, vm| {
820+
Self::number_protocol_binop(number, other, "&", |a, b| a & b, vm)
821+
}),
822+
xor: Some(|number, other, vm| {
823+
Self::number_protocol_binop(number, other, "^", |a, b| a ^ b, vm)
824+
}),
825+
or: Some(|number, other, vm| {
826+
Self::number_protocol_binop(number, other, "|", |a, b| a | b, vm)
827+
}),
828+
int: Some(|number, other| Ok(Self::number_protocol_int(number, other))),
829+
float: Some(|number, vm| {
830+
let zelf = number
831+
.obj
832+
.downcast_ref::<Self>()
833+
.ok_or_else(|| vm.new_type_error("an integer is required".to_owned()))?;
834+
try_to_float(&zelf.value, vm).map(|x| vm.ctx.new_float(x))
835+
}),
836+
floor_divide: Some(|number, other, vm| {
837+
let (a, b) = Self::downcast_or_binop_error(number, other, "//", vm)?;
838+
inner_floordiv(&a.value, &b.value, vm)
839+
}),
840+
true_divide: Some(|number, other, vm| {
841+
let (a, b) = Self::downcast_or_binop_error(number, other, "/", vm)?;
842+
inner_truediv(&a.value, &b.value, vm)
753843
}),
754-
subtract: todo!(),
755-
multiply: todo!(),
756-
remainder: todo!(),
757-
divmod: todo!(),
758-
power: todo!(),
759-
negative: todo!(),
760-
positive: todo!(),
761-
absolute: todo!(),
762-
boolean: todo!(),
763-
invert: todo!(),
764-
lshift: todo!(),
765-
rshift: todo!(),
766-
and: todo!(),
767-
xor: todo!(),
768-
or: todo!(),
769-
int: todo!(),
770-
float: todo!(),
771-
inplace_add: todo!(),
772-
inplace_substract: todo!(),
773-
inplace_multiply: todo!(),
774-
inplace_remainder: todo!(),
775-
inplace_divmod: todo!(),
776-
inplace_power: todo!(),
777-
inplace_lshift: todo!(),
778-
inplace_rshift: todo!(),
779-
inplace_and: todo!(),
780-
inplace_xor: todo!(),
781-
inplace_or: todo!(),
782-
floor_divide: todo!(),
783-
true_divide: todo!(),
784-
inplace_floor_divide: todo!(),
785-
inplace_true_devide: todo!(),
786-
index: todo!(),
787-
matrix_multiply: todo!(),
788-
inplace_matrix_multiply: todo!(),
844+
index: Some(|number, vm| Ok(Self::number_protocol_int(number, vm))),
845+
..*PyNumberMethods::not_implemented()
789846
};
790847
}
791848

vm/src/stdlib/collections.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,14 +529,19 @@ mod _collections {
529529
_zelf: &crate::PyObjectView<Self>,
530530
_vm: &VirtualMachine,
531531
) -> Cow<'static, crate::protocol::PyNumberMethods> {
532-
Cow::Borrowed(&NUMBER_METHODS)
532+
Cow::Borrowed(&Self::NUMBER_METHODS)
533533
}
534534
}
535535

536-
static NUMBER_METHODS: PyNumberMethods = PyNumberMethods {
537-
boolean: Some(|number, vm| Ok(number.try_unary_for::<PyDeque>("bool", vm)?.bool())),
538-
..*PyNumberMethods::not_implemented()
539-
};
536+
impl PyDeque {
537+
const NUMBER_METHODS: PyNumberMethods = PyNumberMethods {
538+
// boolean: Some(|number, vm| Ok(number.try_unary_for::<PyDeque>("bool", vm)?.bool())),
539+
boolean: Some(|number, vm| {
540+
Self::downcast_or_unary_error(number, "bool", vm).map(|x| x.bool())
541+
}),
542+
..*PyNumberMethods::not_implemented()
543+
};
544+
}
540545

541546
#[pyattr]
542547
#[pyclass(name = "_deque_iterator")]

vm/src/types/slot.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,10 @@ pub trait AsNumber: PyValue {
827827
.downcast_ref()
828828
.ok_or_else(|| vm.new_unsupported_unary_error(number.obj, op))
829829
}
830+
831+
fn number_downcast<'a>(number: &'a PyNumber) -> &'a PyObjectView<Self> {
832+
unsafe { number.obj.downcast_unchecked_ref() }
833+
}
830834
}
831835

832836
#[pyimpl]

0 commit comments

Comments
 (0)