From beeb4499f2f13e4eab24444d10a6e4968f36cbf9 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 11 Nov 2021 20:21:00 +0200 Subject: [PATCH 1/7] PyNumber protocol --- vm/src/builtins/int.rs | 2 +- vm/src/protocol/number.rs | 161 ++++++++++++++++++++++++++++++++++++++ vm/src/types/slot.rs | 2 + 3 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 vm/src/protocol/number.rs diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index ecca0b544e..be1b64d22b 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -807,7 +807,7 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult Option { +pub(crate) fn bytes_to_int(lit: &[u8], mut base: u32) -> Option { // split sign let mut lit = lit.trim(); let sign = match lit.first()? { diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs new file mode 100644 index 0000000000..aee744b271 --- /dev/null +++ b/vm/src/protocol/number.rs @@ -0,0 +1,161 @@ +use std::borrow::Cow; + +use crate::{ + builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, + common::{lock::OnceCell, static_cell}, + function::ArgBytesLike, + IdProtocol, PyObject, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, + VirtualMachine, +}; + +#[allow(clippy::type_complexity)] +#[derive(Default, Clone)] +pub struct PyNumberMethods { + /* Number implementations must check *both* + arguments for proper type and implement the necessary conversions + in the slot functions themselves. */ + pub add: Option PyResult>, + pub subtract: Option PyResult>, + pub multiply: Option PyResult>, + pub remainder: Option PyResult>, + pub divmod: Option PyResult>, + pub power: Option PyResult>, + pub negative: Option PyResult>, + pub positive: Option PyResult>, + pub absolute: Option PyResult>, + pub boolean: Option PyResult>, + pub invert: Option PyResult>, + pub lshift: Option PyResult>, + pub rshift: Option PyResult>, + pub and: Option PyResult>, + pub xor: Option PyResult>, + pub or: Option PyResult>, + pub int: Option PyResult>, + pub float: Option PyResult>>, + + pub inplace_add: Option PyResult>, + pub inplace_substract: Option PyResult>, + pub inplace_multiply: Option PyResult>, + pub inplace_remainder: Option PyResult>, + pub inplace_divmod: Option PyResult>, + pub inplace_power: Option PyResult>, + pub inplace_lshift: Option PyResult>, + pub inplace_rshift: Option PyResult>, + pub inplace_and: Option PyResult>, + pub inplace_xor: Option PyResult>, + pub inplace_or: Option PyResult>, + + pub floor_divide: Option PyResult>, + pub true_divide: Option PyResult>, + pub inplace_floor_divide: Option PyResult>, + pub inplace_true_devide: Option PyResult>, + + pub index: Option PyResult>, + + pub matrix_multiply: Option PyResult>, + pub inplace_matrix_multiply: Option PyResult>, +} + +impl PyNumberMethods { + fn not_implemented() -> &'static Self { + static_cell! { + static NOT_IMPLEMENTED: PyNumberMethods; + } + NOT_IMPLEMENTED.get_or_init(Self::default) + } +} + +pub struct PyNumber<'a> { + pub obj: &'a PyObject, + // some fast path do not need methods, so we do lazy initialize + methods: OnceCell>, +} + +impl<'a> From<&'a PyObject> for PyNumber<'a> { + fn from(obj: &'a PyObject) -> Self { + Self { + obj, + methods: OnceCell::new(), + } + } +} + +impl<'a> PyNumber<'a> { + pub fn methods(&'a self, vm: &VirtualMachine) -> &'a Cow<'static, PyNumberMethods> { + self.methods.get_or_init(|| { + self.obj + .class() + .mro_find_map(|x| x.slots.as_number.load()) + .map(|f| f(self.obj, vm)) + .unwrap_or_else(|| Cow::Borrowed(PyNumberMethods::not_implemented())) + }) + } +} + +impl PyNumber<'_> { + // PyNumber_Check + pub fn is_numeric(&self, vm: &VirtualMachine) -> bool { + let methods = self.methods(vm); + methods.int.is_some() + || methods.index.is_some() + || methods.float.is_some() + || self.obj.payload_is::() + } + + // PyIndex_Check + pub fn is_index(&self, vm: &VirtualMachine) -> bool { + self.methods(vm).index.is_some() + } + + pub fn to_int(&self, vm: &VirtualMachine) -> PyResult { + fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { + let base = 10; + match int::bytes_to_int(lit, base) { + Some(i) => Ok(PyInt::from(i).into_ref(vm)), + None => Err(vm.new_value_error(format!( + "invalid literal for int() with base {}: {}", + base, + obj.repr(vm)?, + ))), + } + } + + if self.obj.class().is(PyInt::class(vm)) { + Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + } else if let Some(f) = self.methods(vm).int { + f(self, vm) + } else if let Some(f) = self.methods(vm).index { + f(self, vm) + } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), "__trunc__") { + let r = f.invoke((), vm)?; + PyNumber::from(r.as_ref()).to_index(vm) + } else if let Some(s) = self.obj.payload::() { + try_convert(self.obj, s.as_str().as_bytes(), vm) + } else if let Some(bytes) = self.obj.payload::() { + try_convert(self.obj, bytes, vm) + } else if let Some(bytearray) = self.obj.payload::() { + try_convert(self.obj, &bytearray.borrow_buf(), vm) + } else if let Ok(buffer) = ArgBytesLike::try_from_borrowed_object(vm, self.obj) { + // TODO: replace to PyBuffer + try_convert(self.obj, &buffer.borrow_buf(), vm) + } else { + Err(vm.new_type_error(format!( + "int() argument must be a string, a bytes-like object or a real number, not '{}'", + self.obj.class() + ))) + } + } + + pub fn to_index(&self, vm: &VirtualMachine) -> PyResult { + if self.obj.class().is(PyInt::class(vm)) { + Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + } else if let Some(f) = self.methods(vm).index { + f(self, vm) + } else { + Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + self.obj.class() + ))) + } + } +} diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 245e7a81c8..0dac14a21c 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,4 +1,5 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; +use crate::protocol::PyNumberMethods; use crate::{ builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, @@ -138,6 +139,7 @@ impl Default for PyTypeFlags { pub(crate) type GenericMethod = fn(&PyObject, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type AsMappingFunc = fn(&PyObject, &VirtualMachine) -> &'static PyMappingMethods; +pub(crate) type AsNumberFunc = fn(&PyObject, &VirtualMachine) -> Cow<'static, PyNumberMethods>; pub(crate) type HashFunc = fn(&PyObject, &VirtualMachine) -> PyResult; // CallFunc = GenericMethod pub(crate) type GetattroFunc = fn(&PyObject, PyStrRef, &VirtualMachine) -> PyResult; From 20632edc5dfe0286fc9f5f039d84fe37c6bd6ee9 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 12 Nov 2021 16:09:32 +0200 Subject: [PATCH 2/7] Number protocol for PyInt --- vm/src/builtins/int.rs | 175 +++++++++++++++++++++++++------------- vm/src/protocol/mod.rs | 2 + vm/src/protocol/number.rs | 153 ++++++++++++++++++++------------- vm/src/types/slot.rs | 76 ++++++++++++++++- 4 files changed, 285 insertions(+), 121 deletions(-) diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index be1b64d22b..524b6bc8b4 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -17,6 +17,7 @@ use bstr::ByteSlice; use num_bigint::{BigInt, BigUint, Sign}; use num_integer::Integer; use num_traits::{One, Pow, PrimInt, Signed, ToPrimitive, Zero}; +use std::borrow::Cow; use std::fmt; /// int(x=0) -> integer @@ -263,7 +264,10 @@ impl Constructor for PyInt { val }; - try_int(&val, vm) + // try_int(&val, vm) + PyNumber::from(val.as_ref()) + .int(vm) + .map(|x| x.as_bigint().clone()) } } else if let OptionalArg::Present(_) = options.base { Err(vm.new_type_error("int() missing string argument".to_owned())) @@ -346,7 +350,7 @@ impl PyInt { } } -#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor))] +#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor, AsNumber))] impl PyInt { #[pymethod(name = "__radd__")] #[pymethod(magic)] @@ -751,6 +755,114 @@ impl Hashable for PyInt { } } +impl AsNumber for PyInt { + fn as_number(_zelf: &crate::Py, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { + Cow::Borrowed(&Self::NUMBER_METHODS) + } +} + +impl PyInt { + fn number_protocol_binop( + number: &PyNumber, + other: &PyObject, + op: &str, + f: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(&BigInt, &BigInt) -> BigInt, + { + let (a, b) = Self::downcast_or_binop_error(number, other, op, vm)?; + let ret = f(&a.value, &b.value); + Ok(vm.ctx.new_int(ret).into()) + } + + fn number_protocol_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef { + if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { + zelf.to_owned() + } else { + let zelf = Self::number_downcast(number); + vm.ctx.new_int(zelf.value.clone()) + } + } + + const NUMBER_METHODS: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "+", |a, b| a + b, vm) + }), + subtract: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "-", |a, b| a - b, vm) + }), + multiply: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "*", |a, b| a * b, vm) + }), + remainder: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "%", vm)?; + inner_mod(&a.value, &b.value, vm) + }), + divmod: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "divmod()", vm)?; + inner_divmod(&a.value, &b.value, vm) + }), + power: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "** or pow()", vm)?; + inner_pow(&a.value, &b.value, vm) + }), + negative: Some(|number, vm| { + let zelf = Self::number_downcast(number); + Ok(vm.ctx.new_int(-&zelf.value).into()) + }), + positive: Some(|number, vm| Ok(Self::number_protocol_int(number, vm).into())), + absolute: Some(|number, vm| { + let zelf = Self::number_downcast(number); + Ok(vm.ctx.new_int(zelf.value.abs()).into()) + }), + boolean: Some(|number, _vm| { + let zelf = Self::number_downcast(number); + Ok(zelf.value.is_zero()) + }), + invert: Some(|number, vm| { + let zelf = Self::number_downcast(number); + Ok(vm.ctx.new_int(!&zelf.value).into()) + }), + lshift: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "<<", vm)?; + inner_lshift(&a.value, &b.value, vm) + }), + rshift: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, ">>", vm)?; + inner_rshift(&a.value, &b.value, vm) + }), + and: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "&", |a, b| a & b, vm) + }), + xor: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "^", |a, b| a ^ b, vm) + }), + or: Some(|number, other, vm| { + Self::number_protocol_binop(number, other, "|", |a, b| a | b, vm) + }), + int: Some(|number, other| Ok(Self::number_protocol_int(number, other))), + float: Some(|number, vm| { + let zelf = number + .obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("an integer is required".to_owned()))?; + try_to_float(&zelf.value, vm).map(|x| vm.ctx.new_float(x)) + }), + floor_divide: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "//", vm)?; + inner_floordiv(&a.value, &b.value, vm) + }), + true_divide: Some(|number, other, vm| { + let (a, b) = Self::downcast_or_binop_error(number, other, "/", vm)?; + inner_truediv(&a.value, &b.value, vm) + }), + index: Some(|number, vm| Ok(Self::number_protocol_int(number, vm))), + ..*PyNumberMethods::not_implemented() + }; +} + #[derive(FromArgs)] pub struct IntOptions { #[pyarg(positional, optional)] @@ -926,65 +1038,6 @@ fn i2f(int: &BigInt) -> Option { int.to_f64().filter(|f| f.is_finite()) } -pub(crate) fn try_int(obj: &PyObject, vm: &VirtualMachine) -> PyResult { - fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { - let base = 10; - match bytes_to_int(lit, base) { - Some(i) => Ok(i), - None => Err(vm.new_value_error(format!( - "invalid literal for int() with base {}: {}", - base, - obj.repr(vm)?, - ))), - } - } - - // test for strings and bytes - if let Some(s) = obj.downcast_ref::() { - return try_convert(obj, s.as_str().as_bytes(), vm); - } - if let Ok(r) = obj.try_bytes_like(vm, |x| try_convert(obj, x, vm)) { - return r; - } - // strict `int` check - if let Some(int) = obj.payload_if_exact::(vm) { - return Ok(int.as_bigint().clone()); - } - // call __int__, then __index__, then __trunc__ (converting the __trunc__ result via __index__ if needed) - // TODO: using __int__ is deprecated and removed in Python 3.10 - if let Some(method) = vm.get_method(obj.to_owned(), identifier!(vm, __int__)) { - let result = vm.invoke(&method?, ())?; - return match result.payload::() { - Some(int_obj) => Ok(int_obj.as_bigint().clone()), - None => Err(vm.new_type_error(format!( - "__int__ returned non-int (type '{}')", - result.class().name() - ))), - }; - } - // TODO: returning strict subclasses of int in __index__ is deprecated - if let Some(r) = vm.to_index_opt(obj.to_owned()).transpose()? { - return Ok(r.as_bigint().clone()); - } - if let Some(method) = vm.get_method(obj.to_owned(), identifier!(vm, __trunc__)) { - let result = vm.invoke(&method?, ())?; - return vm - .to_index_opt(result.clone()) - .unwrap_or_else(|| { - Err(vm.new_type_error(format!( - "__trunc__ returned non-Integral (type {})", - result.class().name() - ))) - }) - .map(|int_obj| int_obj.as_bigint().clone()); - } - - Err(vm.new_type_error(format!( - "int() argument must be a string, a bytes-like object or a number, not '{}'", - obj.class().name() - ))) -} - pub(crate) fn init(context: &Context) { PyInt::extend_class(context, context.types.int_type); } diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index 9415bb455e..dae492628e 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -1,10 +1,12 @@ mod buffer; mod iter; mod mapping; +mod number; mod object; mod sequence; pub use buffer::{BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, VecBuffer}; pub use iter::{PyIter, PyIterIter, PyIterReturn}; pub use mapping::{PyMapping, PyMappingMethods}; +pub use number::{PyNumber, PyNumberMethods}; pub use sequence::{PySequence, PySequenceMethods}; diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index aee744b271..7cca188b13 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -2,66 +2,62 @@ use std::borrow::Cow; use crate::{ builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, - common::{lock::OnceCell, static_cell}, + common::lock::OnceCell, function::ArgBytesLike, - IdProtocol, PyObject, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, - VirtualMachine, + AsObject, PyObject, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; #[allow(clippy::type_complexity)] -#[derive(Default, Clone)] +#[derive(Clone)] pub struct PyNumberMethods { /* Number implementations must check *both* arguments for proper type and implement the necessary conversions in the slot functions themselves. */ - pub add: Option PyResult>, - pub subtract: Option PyResult>, - pub multiply: Option PyResult>, - pub remainder: Option PyResult>, - pub divmod: Option PyResult>, - pub power: Option PyResult>, - pub negative: Option PyResult>, - pub positive: Option PyResult>, - pub absolute: Option PyResult>, - pub boolean: Option PyResult>, - pub invert: Option PyResult>, - pub lshift: Option PyResult>, - pub rshift: Option PyResult>, - pub and: Option PyResult>, - pub xor: Option PyResult>, - pub or: Option PyResult>, - pub int: Option PyResult>, - pub float: Option PyResult>>, - - pub inplace_add: Option PyResult>, - pub inplace_substract: Option PyResult>, - pub inplace_multiply: Option PyResult>, - pub inplace_remainder: Option PyResult>, - pub inplace_divmod: Option PyResult>, - pub inplace_power: Option PyResult>, - pub inplace_lshift: Option PyResult>, - pub inplace_rshift: Option PyResult>, - pub inplace_and: Option PyResult>, - pub inplace_xor: Option PyResult>, - pub inplace_or: Option PyResult>, - - pub floor_divide: Option PyResult>, - pub true_divide: Option PyResult>, - pub inplace_floor_divide: Option PyResult>, - pub inplace_true_devide: Option PyResult>, - - pub index: Option PyResult>, - - pub matrix_multiply: Option PyResult>, - pub inplace_matrix_multiply: Option PyResult>, + pub add: Option PyResult>, + pub subtract: Option PyResult>, + pub multiply: Option PyResult>, + pub remainder: Option PyResult>, + pub divmod: Option PyResult>, + pub power: Option PyResult>, + pub negative: Option PyResult>, + pub positive: Option PyResult>, + pub absolute: Option PyResult>, + pub boolean: Option PyResult>, + pub invert: Option PyResult>, + pub lshift: Option PyResult>, + pub rshift: Option PyResult>, + pub and: Option PyResult>, + pub xor: Option PyResult>, + pub or: Option PyResult>, + pub int: Option PyResult>, + pub float: Option PyResult>>, + + pub inplace_add: Option PyResult>, + pub inplace_substract: Option PyResult>, + pub inplace_multiply: Option PyResult>, + pub inplace_remainder: Option PyResult>, + pub inplace_divmod: Option PyResult>, + pub inplace_power: Option PyResult>, + pub inplace_lshift: Option PyResult>, + pub inplace_rshift: Option PyResult>, + pub inplace_and: Option PyResult>, + pub inplace_xor: Option PyResult>, + pub inplace_or: Option PyResult>, + + pub floor_divide: Option PyResult>, + pub true_divide: Option PyResult>, + pub inplace_floor_divide: Option PyResult>, + pub inplace_true_devide: Option PyResult>, + + pub index: Option PyResult>, + + pub matrix_multiply: Option PyResult>, + pub inplace_matrix_multiply: Option PyResult>, } impl PyNumberMethods { - fn not_implemented() -> &'static Self { - static_cell! { - static NOT_IMPLEMENTED: PyNumberMethods; - } - NOT_IMPLEMENTED.get_or_init(Self::default) + pub const fn not_implemented() -> &'static Self { + &NOT_IMPLEMENTED } } @@ -80,8 +76,12 @@ impl<'a> From<&'a PyObject> for PyNumber<'a> { } } -impl<'a> PyNumber<'a> { - pub fn methods(&'a self, vm: &VirtualMachine) -> &'a Cow<'static, PyNumberMethods> { +impl PyNumber<'_> { + pub fn methods(&self, vm: &VirtualMachine) -> &PyNumberMethods { + &*self.methods_cow(vm) + } + + pub fn methods_cow(&self, vm: &VirtualMachine) -> &Cow<'static, PyNumberMethods> { self.methods.get_or_init(|| { self.obj .class() @@ -90,11 +90,9 @@ impl<'a> PyNumber<'a> { .unwrap_or_else(|| Cow::Borrowed(PyNumberMethods::not_implemented())) }) } -} -impl PyNumber<'_> { // PyNumber_Check - pub fn is_numeric(&self, vm: &VirtualMachine) -> bool { + pub fn check(&self, vm: &VirtualMachine) -> bool { let methods = self.methods(vm); methods.int.is_some() || methods.index.is_some() @@ -107,7 +105,7 @@ impl PyNumber<'_> { self.methods(vm).index.is_some() } - pub fn to_int(&self, vm: &VirtualMachine) -> PyResult { + pub fn int(&self, vm: &VirtualMachine) -> PyResult { fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { let base = 10; match int::bytes_to_int(lit, base) { @@ -128,7 +126,9 @@ impl PyNumber<'_> { f(self, vm) } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), "__trunc__") { let r = f.invoke((), vm)?; - PyNumber::from(r.as_ref()).to_index(vm) + PyNumber::from(r.as_ref()).index(vm).map_err(|_| { + vm.new_type_error("__trunc__ returned non-Integral (type NonIntegral)".to_string()) + }) } else if let Some(s) = self.obj.payload::() { try_convert(self.obj, s.as_str().as_bytes(), vm) } else if let Some(bytes) = self.obj.payload::() { @@ -146,7 +146,7 @@ impl PyNumber<'_> { } } - pub fn to_index(&self, vm: &VirtualMachine) -> PyResult { + pub fn index(&self, vm: &VirtualMachine) -> PyResult { if self.obj.class().is(PyInt::class(vm)) { Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) } else if let Some(f) = self.methods(vm).index { @@ -159,3 +159,42 @@ impl PyNumber<'_> { } } } + +const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods { + add: None, + subtract: None, + multiply: None, + remainder: None, + divmod: None, + power: None, + negative: None, + positive: None, + absolute: None, + boolean: None, + invert: None, + lshift: None, + rshift: None, + and: None, + xor: None, + or: None, + int: None, + float: None, + inplace_add: None, + inplace_substract: None, + inplace_multiply: None, + inplace_remainder: None, + inplace_divmod: None, + inplace_power: None, + inplace_lshift: None, + inplace_rshift: None, + inplace_and: None, + inplace_xor: None, + inplace_or: None, + floor_divide: None, + true_divide: None, + inplace_floor_divide: None, + inplace_true_devide: None, + index: None, + matrix_multiply: None, + inplace_matrix_multiply: None, +}; diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 0dac14a21c..bd9d3b0571 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,5 +1,4 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; -use crate::protocol::PyNumberMethods; use crate::{ builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, @@ -8,7 +7,8 @@ use crate::{ function::{FromArgs, FuncArgs, OptionalArg, PyComparisonValue}, identifier, protocol::{ - PyBuffer, PyIterReturn, PyMapping, PyMappingMethods, PySequence, PySequenceMethods, + PyBuffer, PyIterReturn, PyMapping, PyMappingMethods, PyNumber, PyNumberMethods, PySequence, + PySequenceMethods, }, vm::Context, AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, @@ -30,7 +30,7 @@ pub struct PyTypeSlots { // Methods to implement standard operations // Method suites for standard classes - // tp_as_number + pub as_number: AtomicCell>, pub as_sequence: AtomicCell>, pub as_mapping: AtomicCell>, @@ -340,6 +340,39 @@ fn as_sequence_generic(zelf: &PyObject, vm: &VirtualMachine) -> &'static PySeque static_as_sequence_generic(has_length, has_ass_item) } +fn as_number_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { + Cow::Owned(PyNumberMethods { + int: then_some_closure!(zelf.class().has_attr("__int__"), |num, vm| { + let ret = vm.call_special_method(num.obj.to_owned(), "__int__", ())?; + if let Some(i) = ret.downcast_ref::() { + Ok(i.to_owned()) + } else { + // TODO + Err(vm.new_type_error("".to_string())) + } + }), + float: then_some_closure!(zelf.class().has_attr("__float__"), |num, vm| { + let ret = vm.call_special_method(num.obj.to_owned(), "__float__", ())?; + if let Some(f) = ret.downcast_ref::() { + Ok(f.to_owned()) + } else { + // TODO + Err(vm.new_type_error("".to_string())) + } + }), + index: then_some_closure!(zelf.class().has_attr("__index__"), |num, vm| { + let ret = vm.call_special_method(num.obj.to_owned(), "__index__", ())?; + if let Some(i) = ret.downcast_ref::() { + Ok(i.to_owned()) + } else { + // TODO + Err(vm.new_type_error("".to_string())) + } + }), + ..*PyNumberMethods::not_implemented() + }) +} + fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { let hash_obj = vm.call_special_method(zelf.to_owned(), identifier!(vm, __hash__), ())?; match hash_obj.payload_if_subclass::(vm) { @@ -491,6 +524,9 @@ impl PyType { "__del__" => { update_slot!(del, del_wrapper); } + "__int__" | "__index__" | "__float__" => { + update_slot!(as_number, as_number_wrapper); + } _ => {} } } @@ -993,6 +1029,40 @@ pub trait AsSequence: PyPayload { } } +#[pyimpl] +pub trait AsNumber: PyPayload { + #[inline] + #[pyslot] + fn slot_as_number(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { + let zelf = unsafe { zelf.downcast_unchecked_ref::() }; + Self::as_number(zelf, vm) + } + + fn as_number(zelf: &Py, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods>; + + fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py { + unsafe { number.obj.downcast_unchecked_ref::() } + } + + fn downcast_or_binop_error<'a, 'b>( + a: &'a PyNumber, + b: &'b PyObject, + op: &str, + vm: &VirtualMachine, + ) -> PyResult<(&'a Self, &'b Self)> { + if let (Some(a), Some(b)) = (a.obj.payload::(), b.payload::()) { + Ok((a, b)) + } else { + Err(vm.new_type_error(format!( + "unsupported operand type(s) for {}: '{}' and '{}'", + op, + a.obj.class(), + b.class() + ))) + } + } +} + #[pyimpl] pub trait Iterable: PyPayload { #[pyslot] From 59cedd2213dffb3b1411c9db6a094109c9a8f6b3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 30 May 2022 07:18:00 +0900 Subject: [PATCH 3/7] deprecated warnings for int --- Lib/test/test_int.py | 2 -- vm/src/protocol/number.rs | 47 +++++++++++++++++++++++++++++++-------- vm/src/stdlib/mod.rs | 2 +- vm/src/stdlib/warnings.rs | 20 +++++++++++++++++ vm/src/types/slot.rs | 30 ++++++++++--------------- 5 files changed, 71 insertions(+), 30 deletions(-) diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 87cde63c7c..7ac83288f4 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -458,8 +458,6 @@ def __int__(self): self.assertEqual(my_int, 7) self.assertRaises(TypeError, int, my_int) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_int_returns_int_subclass(self): class BadIndex: def __index__(self): diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 7cca188b13..160f4cd438 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -4,6 +4,7 @@ use crate::{ builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, common::lock::OnceCell, function::ArgBytesLike, + stdlib::warnings, AsObject, PyObject, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; @@ -119,15 +120,34 @@ impl PyNumber<'_> { } if self.obj.class().is(PyInt::class(vm)) { - Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + Ok(unsafe { self.obj.to_owned().downcast_unchecked::() }) } else if let Some(f) = self.methods(vm).int { - f(self, vm) - } else if let Some(f) = self.methods(vm).index { - f(self, vm) + let ret = f(self, vm)?; + if !ret.class().is(PyInt::class(vm)) { + warnings::warn( + vm.ctx.exceptions.deprecation_warning.clone(), + format!("__int__ returned non-int (type {})", ret.class()), + 1, + vm, + )? + } + Ok(ret) + } else if self.methods(vm).index.is_some() { + self.index(vm) } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), "__trunc__") { - let r = f.invoke((), vm)?; - PyNumber::from(r.as_ref()).index(vm).map_err(|_| { - vm.new_type_error("__trunc__ returned non-Integral (type NonIntegral)".to_string()) + // TODO: Deprecate in 3.11 + // warnings::warn( + // vm.ctx.exceptions.deprecation_warning.clone(), + // "The delegation of int() to __trunc__ is deprecated.".to_owned(), + // 1, + // vm, + // )?; + let ret = f.invoke((), vm)?; + PyNumber::from(ret.as_ref()).index(vm).map_err(|_| { + vm.new_type_error(format!( + "__trunc__ returned non-Integral (type {})", + ret.class() + )) }) } else if let Some(s) = self.obj.payload::() { try_convert(self.obj, s.as_str().as_bytes(), vm) @@ -148,9 +168,18 @@ impl PyNumber<'_> { pub fn index(&self, vm: &VirtualMachine) -> PyResult { if self.obj.class().is(PyInt::class(vm)) { - Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + Ok(unsafe { self.obj.to_owned().downcast_unchecked::() }) } else if let Some(f) = self.methods(vm).index { - f(self, vm) + let ret = f(self, vm)?; + if !ret.class().is(PyInt::class(vm)) { + warnings::warn( + vm.ctx.exceptions.deprecation_warning.clone(), + format!("__index__ returned non-int (type {})", ret.class()), + 1, + vm, + )? + } + Ok(ret) } else { Err(vm.new_type_error(format!( "'{}' object cannot be interpreted as an integer", diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 8e133a0764..c791dafc6e 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -21,7 +21,7 @@ mod sysconfigdata; #[cfg(feature = "threading")] mod thread; pub mod time; -mod warnings; +pub mod warnings; mod weakref; #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] diff --git a/vm/src/stdlib/warnings.rs b/vm/src/stdlib/warnings.rs index edc3b108e1..e9b5f30eff 100644 --- a/vm/src/stdlib/warnings.rs +++ b/vm/src/stdlib/warnings.rs @@ -1,5 +1,25 @@ pub(crate) use _warnings::make_module; +use crate::{builtins::PyTypeRef, PyResult, VirtualMachine}; + +pub fn warn( + category: PyTypeRef, + message: String, + stack_level: usize, + vm: &VirtualMachine, +) -> PyResult<()> { + // let module = vm.import("warnings", None, 0)?; + // let func = module.get_attr("warn", vm)?; + // vm.invoke(&func, (message, category, stack_level))?; + // TODO + if let Ok(module) = vm.import("warnings", None, 0) { + if let Ok(func) = module.get_attr("warn", vm) { + let _ = vm.invoke(&func, (message, category, stack_level)); + } + } + Ok(()) +} + #[pymodule] mod _warnings { use crate::{ diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index bd9d3b0571..3811f0c972 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -344,30 +344,24 @@ fn as_number_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PyNu Cow::Owned(PyNumberMethods { int: then_some_closure!(zelf.class().has_attr("__int__"), |num, vm| { let ret = vm.call_special_method(num.obj.to_owned(), "__int__", ())?; - if let Some(i) = ret.downcast_ref::() { - Ok(i.to_owned()) - } else { - // TODO - Err(vm.new_type_error("".to_string())) - } + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) + }) }), float: then_some_closure!(zelf.class().has_attr("__float__"), |num, vm| { let ret = vm.call_special_method(num.obj.to_owned(), "__float__", ())?; - if let Some(f) = ret.downcast_ref::() { - Ok(f.to_owned()) - } else { - // TODO - Err(vm.new_type_error("".to_string())) - } + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "__float__ returned non-float (type {})", + obj.class() + )) + }) }), index: then_some_closure!(zelf.class().has_attr("__index__"), |num, vm| { let ret = vm.call_special_method(num.obj.to_owned(), "__index__", ())?; - if let Some(i) = ret.downcast_ref::() { - Ok(i.to_owned()) - } else { - // TODO - Err(vm.new_type_error("".to_string())) - } + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) + }) }), ..*PyNumberMethods::not_implemented() }) From 3e6e348a67920da8805bdbb62a236d0a150d9e17 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 25 May 2022 21:12:16 +0200 Subject: [PATCH 4/7] impl number protocol float --- Lib/test/test_float.py | 2 - stdlib/src/math.rs | 8 +-- vm/src/buffer.rs | 8 +-- vm/src/builtins/complex.rs | 4 +- vm/src/builtins/float.rs | 125 +++++++++++++++++++++++++------------ vm/src/builtins/int.rs | 118 +++++++++++++--------------------- vm/src/function/number.rs | 6 +- vm/src/protocol/number.rs | 62 ++++++++++++++++-- vm/src/stdlib/warnings.rs | 6 +- vm/src/types/slot.rs | 78 +++++++++++------------ 10 files changed, 236 insertions(+), 181 deletions(-) diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py index 39208dab9a..bc3d5bdc7c 100644 --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -177,8 +177,6 @@ def test_float_with_comma(self): self.assertEqual(float(" 25.e-1 "), 2.5) self.assertAlmostEqual(float(" .25e-1 "), .025) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_floatconversion(self): # Make sure that calls to __float__() work properly class Foo1(object): diff --git a/stdlib/src/math.rs b/stdlib/src/math.rs index e118ce77c7..08f450323e 100644 --- a/stdlib/src/math.rs +++ b/stdlib/src/math.rs @@ -454,8 +454,8 @@ mod math { fn ceil(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { let result_or_err = try_magic_method(identifier!(vm, __ceil__), vm, &x); if result_or_err.is_err() { - if let Ok(Some(v)) = x.try_to_f64(vm) { - let v = try_f64_to_bigint(v.ceil(), vm)?; + if let Ok(Some(v)) = x.try_float_opt(vm) { + let v = try_f64_to_bigint(v.to_f64().ceil(), vm)?; return Ok(vm.ctx.new_int(v).into()); } } @@ -466,8 +466,8 @@ mod math { fn floor(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { let result_or_err = try_magic_method(identifier!(vm, __floor__), vm, &x); if result_or_err.is_err() { - if let Ok(Some(v)) = x.try_to_f64(vm) { - let v = try_f64_to_bigint(v.floor(), vm)?; + if let Ok(Some(v)) = x.try_float_opt(vm) { + let v = try_f64_to_bigint(v.to_f64().floor(), vm)?; return Ok(vm.ctx.new_int(v).into()); } } diff --git a/vm/src/buffer.rs b/vm/src/buffer.rs index 046162fb14..f2e05e828c 100644 --- a/vm/src/buffer.rs +++ b/vm/src/buffer.rs @@ -1,8 +1,8 @@ use crate::{ - builtins::{float, PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef}, + builtins::{PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef}, common::{static_cell, str::wchar_t}, convert::ToPyObject, - function::{ArgBytesLike, ArgIntoBool}, + function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat}, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use half::f16; @@ -521,7 +521,7 @@ macro_rules! make_pack_float { arg: PyObjectRef, data: &mut [u8], ) -> PyResult<()> { - let f = float::try_float(&arg, vm)? as $T; + let f = *ArgIntoFloat::try_from_object(vm, arg)? as $T; f.to_bits().pack_int::(data); Ok(()) } @@ -539,7 +539,7 @@ make_pack_float!(f64); impl Packable for f16 { fn pack(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> { - let f_64 = float::try_float(&arg, vm)?; + let f_64 = *ArgIntoFloat::try_from_object(vm, arg)?; let f_16 = f16::from_f64(f_64); if f_16.is_infinite() != f_64.is_infinite() { return Err(vm.new_overflow_error("float too large to pack with e format".to_owned())); diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index 2cfc84c7bb..9e4f5af904 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -65,8 +65,8 @@ impl PyObjectRef { if let Some(complex) = self.payload_if_subclass::(vm) { return Ok(Some((complex.value, true))); } - if let Some(float) = self.try_to_f64(vm)? { - return Ok(Some((Complex64::new(float, 0.0), false))); + if let Some(float) = self.try_float_opt(vm)? { + return Ok(Some((Complex64::new(float.to_f64(), 0.0), false))); } Ok(None) } diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index 5453971cbb..7712921170 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -1,18 +1,20 @@ +use std::borrow::Cow; + use super::{ try_bigint_to_f64, PyByteArray, PyBytes, PyInt, PyIntRef, PyStr, PyStrRef, PyType, PyTypeRef, }; -use crate::common::{float_ops, hash}; use crate::{ class::PyClassImpl, - convert::ToPyObject, + common::{float_ops, hash}, + convert::{ToPyObject, ToPyResult}, format::FormatSpec, function::{ ArgBytesLike, OptionalArg, OptionalOption, PyArithmeticValue::{self, *}, PyComparisonValue, }, - identifier, - types::{Comparable, Constructor, Hashable, PyComparisonOp}, + protocol::{PyNumber, PyNumberMethods}, + types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, TryFromObject, VirtualMachine, }; @@ -58,32 +60,13 @@ impl From for PyFloat { } impl PyObject { - pub fn try_to_f64(&self, vm: &VirtualMachine) -> PyResult> { - if let Some(float) = self.payload_if_exact::(vm) { - return Ok(Some(float.value)); - } - if let Some(method) = vm.get_method(self.to_owned(), identifier!(vm, __float__)) { - let result = vm.invoke(&method?, ())?; - // TODO: returning strict subclasses of float in __float__ is deprecated - return match result.payload::() { - Some(float_obj) => Ok(Some(float_obj.value)), - None => Err(vm.new_type_error(format!( - "__float__ returned non-float (type '{}')", - result.class().name() - ))), - }; - } - if let Some(r) = vm.to_index_opt(self.to_owned()).transpose()? { - return Ok(Some(try_bigint_to_f64(r.as_bigint(), vm)?)); - } - Ok(None) + pub fn try_float_opt(&self, vm: &VirtualMachine) -> PyResult>> { + PyNumber::from(self).float_opt(vm) } -} -pub fn try_float(obj: &PyObject, vm: &VirtualMachine) -> PyResult { - obj.try_to_f64(vm)?.ok_or_else(|| { - vm.new_type_error(format!("must be real number, not {}", obj.class().name())) - }) + pub fn try_float(&self, vm: &VirtualMachine) -> PyResult> { + PyNumber::from(self).float(vm) + } } pub(crate) fn to_op_float(obj: &PyObject, vm: &VirtualMachine) -> PyResult> { @@ -170,17 +153,10 @@ impl Constructor for PyFloat { let float_val = match arg { OptionalArg::Missing => 0.0, OptionalArg::Present(val) => { - let val = if cls.is(vm.ctx.types.float_type) { - match val.downcast_exact::(vm) { - Ok(f) => return Ok(f.into()), - Err(val) => val, - } - } else { - val - }; - - if let Some(f) = val.try_to_f64(vm)? { - f + if cls.is(vm.ctx.types.float_type) && val.class().is(PyFloat::class(vm)) { + unsafe { val.downcast_unchecked::().value } + } else if let Some(f) = val.try_float_opt(vm)? { + f.value } else { float_from_string(val, vm)? } @@ -220,7 +196,7 @@ fn float_from_string(val: PyObjectRef, vm: &VirtualMachine) -> PyResult { }) } -#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor))] +#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor, AsNumber))] impl PyFloat { #[pymethod(magic)] fn format(&self, spec: PyStrRef, vm: &VirtualMachine) -> PyResult { @@ -562,6 +538,75 @@ impl Hashable for PyFloat { } } +impl AsNumber for PyFloat { + fn as_number(_zelf: &crate::Py, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { + Cow::Borrowed(&Self::NUMBER_METHODS) + } +} + +impl PyFloat { + fn np_general_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(f64, f64, &VirtualMachine) -> R, + R: ToPyResult, + { + if let (Some(a), Some(b)) = (to_op_float(number.obj, vm)?, to_op_float(other, vm)?) { + op(a, b, vm).to_pyresult(vm) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + fn np_float_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce(f64, f64) -> f64, + { + Self::np_general_op(number, other, |a, b, _vm| op(a, b), vm) + } + + fn np_float(number: &PyNumber, vm: &VirtualMachine) -> PyRef { + if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { + zelf.to_owned() + } else { + vm.ctx.new_float(Self::number_downcast(number).value) + } + } + + const NUMBER_METHODS: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a - b, vm)), + multiply: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a * b, vm)), + remainder: Some(|number, other, vm| Self::np_general_op(number, other, inner_mod, vm)), + divmod: Some(|number, other, vm| Self::np_general_op(number, other, inner_divmod, vm)), + power: Some(|number, other, vm| Self::np_general_op(number, other, float_pow, vm)), + negative: Some(|number, vm| { + let value = Self::number_downcast(number).value; + (-value).to_pyresult(vm) + }), + positive: Some(|number, vm| Self::np_float(number, vm).to_pyresult(vm)), + absolute: Some(|number, vm| { + let value = Self::number_downcast(number).value; + value.abs().to_pyresult(vm) + }), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), + int: Some(|number, vm| { + let value = Self::number_downcast(number).value; + try_to_bigint(value, vm).map(|x| vm.ctx.new_int(x)) + }), + float: Some(|number, vm| Ok(Self::np_float(number, vm))), + floor_divide: Some(|number, other, vm| { + Self::np_general_op(number, other, inner_floordiv, vm) + }), + true_divide: Some(|number, other, vm| Self::np_general_op(number, other, inner_div, vm)), + ..*PyNumberMethods::not_implemented() + }; +} + // Retrieve inner float value: pub(crate) fn get_value(obj: &PyObject) -> f64 { obj.payload::().unwrap().value diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index 524b6bc8b4..c4baea3a2f 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -9,7 +9,8 @@ use crate::{ ArgByteOrder, ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue, PyComparisonValue, }, - types::{Comparable, Constructor, Hashable, PyComparisonOp}, + protocol::{PyNumber, PyNumberMethods}, + types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; @@ -17,8 +18,8 @@ use bstr::ByteSlice; use num_bigint::{BigInt, BigUint, Sign}; use num_integer::Integer; use num_traits::{One, Pow, PrimInt, Signed, ToPrimitive, Zero}; -use std::borrow::Cow; use std::fmt; +use std::{borrow::Cow, ops::Neg}; /// int(x=0) -> integer /// int(x, base=10) -> integer @@ -264,7 +265,6 @@ impl Constructor for PyInt { val }; - // try_int(&val, vm) PyNumber::from(val.as_ref()) .int(vm) .map(|x| x.as_bigint().clone()) @@ -762,22 +762,25 @@ impl AsNumber for PyInt { } impl PyInt { - fn number_protocol_binop( - number: &PyNumber, - other: &PyObject, - op: &str, - f: F, - vm: &VirtualMachine, - ) -> PyResult + fn np_general_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce(&BigInt, &BigInt, &VirtualMachine) -> PyResult, + { + if let (Some(a), Some(b)) = (number.obj.payload::(), other.payload::()) { + op(&a.value, &b.value, vm) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + fn np_int_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult where F: FnOnce(&BigInt, &BigInt) -> BigInt, { - let (a, b) = Self::downcast_or_binop_error(number, other, op, vm)?; - let ret = f(&a.value, &b.value); - Ok(vm.ctx.new_int(ret).into()) + Self::np_general_op(number, other, |a, b, _vm| op(a, b).to_pyresult(vm), vm) } - fn number_protocol_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef { + fn np_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef { if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { zelf.to_owned() } else { @@ -787,78 +790,43 @@ impl PyInt { } const NUMBER_METHODS: PyNumberMethods = PyNumberMethods { - add: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "+", |a, b| a + b, vm) - }), - subtract: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "-", |a, b| a - b, vm) - }), - multiply: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "*", |a, b| a * b, vm) - }), - remainder: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "%", vm)?; - inner_mod(&a.value, &b.value, vm) - }), - divmod: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "divmod()", vm)?; - inner_divmod(&a.value, &b.value, vm) - }), - power: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "** or pow()", vm)?; - inner_pow(&a.value, &b.value, vm) - }), + add: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a - b, vm)), + multiply: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a * b, vm)), + remainder: Some(|number, other, vm| Self::np_general_op(number, other, inner_mod, vm)), + divmod: Some(|number, other, vm| Self::np_general_op(number, other, inner_divmod, vm)), + power: Some(|number, other, vm| Self::np_general_op(number, other, inner_pow, vm)), negative: Some(|number, vm| { - let zelf = Self::number_downcast(number); - Ok(vm.ctx.new_int(-&zelf.value).into()) - }), - positive: Some(|number, vm| Ok(Self::number_protocol_int(number, vm).into())), - absolute: Some(|number, vm| { - let zelf = Self::number_downcast(number); - Ok(vm.ctx.new_int(zelf.value.abs()).into()) - }), - boolean: Some(|number, _vm| { - let zelf = Self::number_downcast(number); - Ok(zelf.value.is_zero()) + Self::number_downcast(number) + .value + .clone() + .neg() + .to_pyresult(vm) }), + positive: Some(|number, vm| Ok(Self::np_int(number, vm).into())), + absolute: Some(|number, vm| Self::number_downcast(number).value.abs().to_pyresult(vm)), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), invert: Some(|number, vm| { - let zelf = Self::number_downcast(number); - Ok(vm.ctx.new_int(!&zelf.value).into()) + let value = Self::number_downcast(number).value.clone(); + (!value).to_pyresult(vm) }), - lshift: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "<<", vm)?; - inner_lshift(&a.value, &b.value, vm) - }), - rshift: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, ">>", vm)?; - inner_rshift(&a.value, &b.value, vm) - }), - and: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "&", |a, b| a & b, vm) - }), - xor: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "^", |a, b| a ^ b, vm) - }), - or: Some(|number, other, vm| { - Self::number_protocol_binop(number, other, "|", |a, b| a | b, vm) - }), - int: Some(|number, other| Ok(Self::number_protocol_int(number, other))), + lshift: Some(|number, other, vm| Self::np_general_op(number, other, inner_lshift, vm)), + rshift: Some(|number, other, vm| Self::np_general_op(number, other, inner_rshift, vm)), + and: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a & b, vm)), + xor: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a ^ b, vm)), + or: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a | b, vm)), + int: Some(|number, other| Ok(Self::np_int(number, other))), float: Some(|number, vm| { - let zelf = number - .obj - .downcast_ref::() - .ok_or_else(|| vm.new_type_error("an integer is required".to_owned()))?; + let zelf = Self::number_downcast(number); try_to_float(&zelf.value, vm).map(|x| vm.ctx.new_float(x)) }), floor_divide: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "//", vm)?; - inner_floordiv(&a.value, &b.value, vm) + Self::np_general_op(number, other, inner_floordiv, vm) }), true_divide: Some(|number, other, vm| { - let (a, b) = Self::downcast_or_binop_error(number, other, "/", vm)?; - inner_truediv(&a.value, &b.value, vm) + Self::np_general_op(number, other, inner_truediv, vm) }), - index: Some(|number, vm| Ok(Self::number_protocol_int(number, vm))), + index: Some(|number, vm| Ok(Self::np_int(number, vm))), ..*PyNumberMethods::not_implemented() }; } diff --git a/vm/src/function/number.rs b/vm/src/function/number.rs index 96ff57827b..62922c85ba 100644 --- a/vm/src/function/number.rs +++ b/vm/src/function/number.rs @@ -1,4 +1,4 @@ -use crate::{AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine}; +use crate::{protocol::PyNumber, AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine}; use num_complex::Complex64; use std::ops::Deref; @@ -82,9 +82,7 @@ impl Deref for ArgIntoFloat { impl TryFromObject for ArgIntoFloat { // Equivalent to PyFloat_AsDouble. fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let value = obj.try_to_f64(vm)?.ok_or_else(|| { - vm.new_type_error(format!("must be real number, not {}", obj.class().name())) - })?; + let value = PyNumber::from(obj.as_ref()).float(vm)?.to_f64(); Ok(ArgIntoFloat { value }) } } diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 160f4cd438..68b575c55d 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -125,8 +125,13 @@ impl PyNumber<'_> { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { warnings::warn( - vm.ctx.exceptions.deprecation_warning.clone(), - format!("__int__ returned non-int (type {})", ret.class()), + vm.ctx.exceptions.deprecation_warning, + format!( + "__int__ returned non-int (type {}). \ + The ability to return an instance of a strict subclass of int \ + is deprecated, and may be removed in a future version of Python.", + ret.class() + ), 1, vm, )? @@ -134,7 +139,9 @@ impl PyNumber<'_> { Ok(ret) } else if self.methods(vm).index.is_some() { self.index(vm) - } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), "__trunc__") { + } else if let Ok(Ok(f)) = + vm.get_special_method(self.obj.to_owned(), identifier!(vm, __trunc__)) + { // TODO: Deprecate in 3.11 // warnings::warn( // vm.ctx.exceptions.deprecation_warning.clone(), @@ -173,8 +180,13 @@ impl PyNumber<'_> { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { warnings::warn( - vm.ctx.exceptions.deprecation_warning.clone(), - format!("__index__ returned non-int (type {})", ret.class()), + vm.ctx.exceptions.deprecation_warning, + format!( + "__index__ returned non-int (type {}). \ + The ability to return an instance of a strict subclass of int \ + is deprecated, and may be removed in a future version of Python.", + ret.class() + ), 1, vm, )? @@ -187,6 +199,46 @@ impl PyNumber<'_> { ))) } } + + pub fn float_opt(&self, vm: &VirtualMachine) -> PyResult>> { + if self.obj.class().is(PyFloat::class(vm)) { + Ok(Some(unsafe { + self.obj.to_owned().downcast_unchecked::() + })) + } else if let Some(f) = self.methods(vm).float { + let ret = f(self, vm)?; + if !ret.class().is(PyFloat::class(vm)) { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + format!( + "__float__ returned non-float (type {}). \ + The ability to return an instance of a strict subclass of float \ + is deprecated, and may be removed in a future version of Python.", + ret.class() + ), + 1, + vm, + )?; + Ok(Some(vm.ctx.new_float(ret.to_f64()))) + } else { + Ok(Some(ret)) + } + } else if self.methods(vm).index.is_some() { + let i = self.index(vm)?; + let value = int::try_to_float(i.as_bigint(), vm)?; + Ok(Some(vm.ctx.new_float(value))) + } else if let Some(value) = self.obj.downcast_ref::() { + Ok(Some(vm.ctx.new_float(value.to_f64()))) + } else { + Ok(None) + } + } + + pub fn float(&self, vm: &VirtualMachine) -> PyResult> { + self.float_opt(vm)?.ok_or_else(|| { + vm.new_type_error(format!("must be real number, not {}", self.obj.class())) + }) + } } const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods { diff --git a/vm/src/stdlib/warnings.rs b/vm/src/stdlib/warnings.rs index e9b5f30eff..4bf14ef228 100644 --- a/vm/src/stdlib/warnings.rs +++ b/vm/src/stdlib/warnings.rs @@ -1,9 +1,9 @@ pub(crate) use _warnings::make_module; -use crate::{builtins::PyTypeRef, PyResult, VirtualMachine}; +use crate::{builtins::PyType, Py, PyResult, VirtualMachine}; pub fn warn( - category: PyTypeRef, + category: &Py, message: String, stack_level: usize, vm: &VirtualMachine, @@ -14,7 +14,7 @@ pub fn warn( // TODO if let Ok(module) = vm.import("warnings", None, 0) { if let Ok(func) = module.get_attr("warn", vm) { - let _ = vm.invoke(&func, (message, category, stack_level)); + let _ = vm.invoke(&func, (message, category.to_owned(), stack_level)); } } Ok(()) diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 3811f0c972..0ca4343984 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,6 +1,6 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; use crate::{ - builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, + builtins::{PyFloat, PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, convert::{ToPyObject, ToPyResult}, function::Either, @@ -340,29 +340,41 @@ fn as_sequence_generic(zelf: &PyObject, vm: &VirtualMachine) -> &'static PySeque static_as_sequence_generic(has_length, has_ass_item) } -fn as_number_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { +fn as_number_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { Cow::Owned(PyNumberMethods { - int: then_some_closure!(zelf.class().has_attr("__int__"), |num, vm| { - let ret = vm.call_special_method(num.obj.to_owned(), "__int__", ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) - }) - }), - float: then_some_closure!(zelf.class().has_attr("__float__"), |num, vm| { - let ret = vm.call_special_method(num.obj.to_owned(), "__float__", ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!( - "__float__ returned non-float (type {})", - obj.class() - )) - }) - }), - index: then_some_closure!(zelf.class().has_attr("__index__"), |num, vm| { - let ret = vm.call_special_method(num.obj.to_owned(), "__index__", ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) - }) - }), + int: then_some_closure!( + zelf.class().has_attr(identifier!(vm, __int__)), + |num, vm| { + let ret = + vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) + }) + } + ), + float: then_some_closure!( + zelf.class().has_attr(identifier!(vm, __float__)), + |num, vm| { + let ret = + vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "__float__ returned non-float (type {})", + obj.class() + )) + }) + } + ), + index: then_some_closure!( + zelf.class().has_attr(identifier!(vm, __index__)), + |num, vm| { + let ret = + vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) + }) + } + ), ..*PyNumberMethods::not_implemented() }) } @@ -1035,25 +1047,7 @@ pub trait AsNumber: PyPayload { fn as_number(zelf: &Py, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods>; fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py { - unsafe { number.obj.downcast_unchecked_ref::() } - } - - fn downcast_or_binop_error<'a, 'b>( - a: &'a PyNumber, - b: &'b PyObject, - op: &str, - vm: &VirtualMachine, - ) -> PyResult<(&'a Self, &'b Self)> { - if let (Some(a), Some(b)) = (a.obj.payload::(), b.payload::()) { - Ok((a, b)) - } else { - Err(vm.new_type_error(format!( - "unsupported operand type(s) for {}: '{}' and '{}'", - op, - a.obj.class(), - b.class() - ))) - } + unsafe { number.obj.downcast_unchecked_ref() } } } From 1ed18c012af12d0c9dd990ade8d42eb9dee74e56 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 22 May 2022 21:00:18 +0200 Subject: [PATCH 5/7] fix vm.to_index now use number protocol --- Lib/test/test_index.py | 2 -- vm/src/protocol/number.rs | 32 +++++++++++++++++++++++--------- vm/src/vm/vm_ops.rs | 24 +++--------------------- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/Lib/test/test_index.py b/Lib/test/test_index.py index 1fac132595..cbdc56c801 100644 --- a/Lib/test/test_index.py +++ b/Lib/test/test_index.py @@ -71,8 +71,6 @@ def __index__(self): self.assertIs(type(direct_index), int) #self.assertIs(type(operator_index), int) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_index_returns_int_subclass(self): class BadInt: def __index__(self): diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 68b575c55d..2868ca8bd3 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -134,9 +134,11 @@ impl PyNumber<'_> { ), 1, vm, - )? + )?; + Ok(vm.ctx.new_int(ret.as_bigint().clone())) + } else { + Ok(ret) } - Ok(ret) } else if self.methods(vm).index.is_some() { self.index(vm) } else if let Ok(Ok(f)) = @@ -173,9 +175,13 @@ impl PyNumber<'_> { } } - pub fn index(&self, vm: &VirtualMachine) -> PyResult { + pub fn index_opt(&self, vm: &VirtualMachine) -> PyResult> { if self.obj.class().is(PyInt::class(vm)) { - Ok(unsafe { self.obj.to_owned().downcast_unchecked::() }) + Ok(Some(unsafe { + self.obj.to_owned().downcast_unchecked::() + })) + } else if let Some(i) = self.obj.downcast_ref::() { + Ok(Some(i.to_owned())) } else if let Some(f) = self.methods(vm).index { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { @@ -189,15 +195,23 @@ impl PyNumber<'_> { ), 1, vm, - )? + )?; + Ok(Some(vm.ctx.new_int(ret.as_bigint().clone()))) + } else { + Ok(Some(ret)) } - Ok(ret) } else { - Err(vm.new_type_error(format!( + Ok(None) + } + } + + pub fn index(&self, vm: &VirtualMachine) -> PyResult { + self.index_opt(vm)?.ok_or_else(|| { + vm.new_type_error(format!( "'{}' object cannot be interpreted as an integer", self.obj.class() - ))) - } + )) + }) } pub fn float_opt(&self, vm: &VirtualMachine) -> PyResult>> { diff --git a/vm/src/vm/vm_ops.rs b/vm/src/vm/vm_ops.rs index e0606aefe7..1c325595cd 100644 --- a/vm/src/vm/vm_ops.rs +++ b/vm/src/vm/vm_ops.rs @@ -3,36 +3,18 @@ use crate::{ builtins::{PyInt, PyIntRef, PyStrInterned}, function::PyArithmeticValue, object::{AsObject, PyObject, PyObjectRef, PyResult}, - protocol::PyIterReturn, + protocol::{PyIterReturn, PyNumber}, types::PyComparisonOp, }; /// Collection of operators impl VirtualMachine { pub fn to_index_opt(&self, obj: PyObjectRef) -> Option> { - match obj.downcast() { - Ok(val) => Some(Ok(val)), - Err(obj) => self - .get_method(obj, identifier!(self, __index__)) - .map(|index| { - // TODO: returning strict subclasses of int in __index__ is deprecated - self.invoke(&index?, ())?.downcast().map_err(|bad| { - self.new_type_error(format!( - "__index__ returned non-int (type {})", - bad.class().name() - )) - }) - }), - } + PyNumber::from(obj.as_ref()).index_opt(self).transpose() } pub fn to_index(&self, obj: &PyObject) -> PyResult { - self.to_index_opt(obj.to_owned()).unwrap_or_else(|| { - Err(self.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - obj.class().name() - ))) - }) + PyNumber::from(obj).index(self) } #[inline] From ea95777ec72bc8c22f9bb7cf7cdadcb1345d46d2 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 29 May 2022 16:09:56 +0900 Subject: [PATCH 6/7] Simplify AsNumber trait --- vm/src/builtins/float.rs | 81 ++++++++++++++-------------- vm/src/builtins/int.rs | 92 +++++++++++++++----------------- vm/src/protocol/number.rs | 108 ++++++++++++++++++-------------------- vm/src/stdlib/warnings.rs | 5 +- vm/src/types/slot.rs | 11 ++-- 5 files changed, 138 insertions(+), 159 deletions(-) diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index 7712921170..cfa2e7e250 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use super::{ try_bigint_to_f64, PyByteArray, PyBytes, PyInt, PyIntRef, PyStr, PyStrRef, PyType, PyTypeRef, }; @@ -153,9 +151,7 @@ impl Constructor for PyFloat { let float_val = match arg { OptionalArg::Missing => 0.0, OptionalArg::Present(val) => { - if cls.is(vm.ctx.types.float_type) && val.class().is(PyFloat::class(vm)) { - unsafe { val.downcast_unchecked::().value } - } else if let Some(f) = val.try_float_opt(vm)? { + if let Some(f) = val.try_float_opt(vm)? { f.value } else { float_from_string(val, vm)? @@ -539,13 +535,40 @@ impl Hashable for PyFloat { } impl AsNumber for PyFloat { - fn as_number(_zelf: &crate::Py, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { - Cow::Borrowed(&Self::NUMBER_METHODS) - } + const AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| Self::number_float_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| Self::number_float_op(number, other, |a, b| a - b, vm)), + multiply: Some(|number, other, vm| Self::number_float_op(number, other, |a, b| a * b, vm)), + remainder: Some(|number, other, vm| Self::number_general_op(number, other, inner_mod, vm)), + divmod: Some(|number, other, vm| Self::number_general_op(number, other, inner_divmod, vm)), + power: Some(|number, other, vm| Self::number_general_op(number, other, float_pow, vm)), + negative: Some(|number, vm| { + let value = Self::number_downcast(number).value; + (-value).to_pyresult(vm) + }), + positive: Some(|number, vm| Self::number_float(number, vm).to_pyresult(vm)), + absolute: Some(|number, vm| { + let value = Self::number_downcast(number).value; + value.abs().to_pyresult(vm) + }), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), + int: Some(|number, vm| { + let value = Self::number_downcast(number).value; + try_to_bigint(value, vm).map(|x| vm.ctx.new_int(x)) + }), + float: Some(|number, vm| Ok(Self::number_float(number, vm))), + floor_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_floordiv, vm) + }), + true_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_div, vm) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; } impl PyFloat { - fn np_general_op( + fn number_general_op( number: &PyNumber, other: &PyObject, op: F, @@ -562,49 +585,25 @@ impl PyFloat { } } - fn np_float_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult + fn number_float_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult where F: FnOnce(f64, f64) -> f64, { - Self::np_general_op(number, other, |a, b, _vm| op(a, b), vm) + Self::number_general_op(number, other, |a, b, _vm| op(a, b), vm) } - fn np_float(number: &PyNumber, vm: &VirtualMachine) -> PyRef { + fn number_float(number: &PyNumber, vm: &VirtualMachine) -> PyRef { if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { zelf.to_owned() } else { vm.ctx.new_float(Self::number_downcast(number).value) } } - - const NUMBER_METHODS: PyNumberMethods = PyNumberMethods { - add: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a + b, vm)), - subtract: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a - b, vm)), - multiply: Some(|number, other, vm| Self::np_float_op(number, other, |a, b| a * b, vm)), - remainder: Some(|number, other, vm| Self::np_general_op(number, other, inner_mod, vm)), - divmod: Some(|number, other, vm| Self::np_general_op(number, other, inner_divmod, vm)), - power: Some(|number, other, vm| Self::np_general_op(number, other, float_pow, vm)), - negative: Some(|number, vm| { - let value = Self::number_downcast(number).value; - (-value).to_pyresult(vm) - }), - positive: Some(|number, vm| Self::np_float(number, vm).to_pyresult(vm)), - absolute: Some(|number, vm| { - let value = Self::number_downcast(number).value; - value.abs().to_pyresult(vm) - }), - boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), - int: Some(|number, vm| { - let value = Self::number_downcast(number).value; - try_to_bigint(value, vm).map(|x| vm.ctx.new_int(x)) - }), - float: Some(|number, vm| Ok(Self::np_float(number, vm))), - floor_divide: Some(|number, other, vm| { - Self::np_general_op(number, other, inner_floordiv, vm) - }), - true_divide: Some(|number, other, vm| Self::np_general_op(number, other, inner_div, vm)), - ..*PyNumberMethods::not_implemented() - }; } // Retrieve inner float value: diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index c4baea3a2f..37405831c7 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -18,8 +18,8 @@ use bstr::ByteSlice; use num_bigint::{BigInt, BigUint, Sign}; use num_integer::Integer; use num_traits::{One, Pow, PrimInt, Signed, ToPrimitive, Zero}; -use std::fmt; -use std::{borrow::Cow, ops::Neg}; +use std::ops::Neg; +use std::{fmt, ops::Not}; /// int(x=0) -> integer /// int(x, base=10) -> integer @@ -756,13 +756,46 @@ impl Hashable for PyInt { } impl AsNumber for PyInt { - fn as_number(_zelf: &crate::Py, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { - Cow::Borrowed(&Self::NUMBER_METHODS) - } + const AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a - b, vm)), + multiply: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a * b, vm)), + remainder: Some(|number, other, vm| Self::number_general_op(number, other, inner_mod, vm)), + divmod: Some(|number, other, vm| Self::number_general_op(number, other, inner_divmod, vm)), + power: Some(|number, other, vm| Self::number_general_op(number, other, inner_pow, vm)), + negative: Some(|number, vm| (&Self::number_downcast(number).value).neg().to_pyresult(vm)), + positive: Some(|number, vm| Ok(Self::number_int(number, vm).into())), + absolute: Some(|number, vm| Self::number_downcast(number).value.abs().to_pyresult(vm)), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), + invert: Some(|number, vm| (&Self::number_downcast(number).value).not().to_pyresult(vm)), + lshift: Some(|number, other, vm| Self::number_general_op(number, other, inner_lshift, vm)), + rshift: Some(|number, other, vm| Self::number_general_op(number, other, inner_rshift, vm)), + and: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a & b, vm)), + xor: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a ^ b, vm)), + or: Some(|number, other, vm| Self::number_int_op(number, other, |a, b| a | b, vm)), + int: Some(|number, other| Ok(Self::number_int(number, other))), + float: Some(|number, vm| { + let zelf = Self::number_downcast(number); + try_to_float(&zelf.value, vm).map(|x| vm.ctx.new_float(x)) + }), + floor_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_floordiv, vm) + }), + true_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_truediv, vm) + }), + index: Some(|number, vm| Ok(Self::number_int(number, vm))), + ..PyNumberMethods::NOT_IMPLEMENTED + }; } impl PyInt { - fn np_general_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult + fn number_general_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult where F: FnOnce(&BigInt, &BigInt, &VirtualMachine) -> PyResult, { @@ -773,14 +806,14 @@ impl PyInt { } } - fn np_int_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult + fn number_int_op(number: &PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult where F: FnOnce(&BigInt, &BigInt) -> BigInt, { - Self::np_general_op(number, other, |a, b, _vm| op(a, b).to_pyresult(vm), vm) + Self::number_general_op(number, other, |a, b, _vm| op(a, b).to_pyresult(vm), vm) } - fn np_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef { + fn number_int(number: &PyNumber, vm: &VirtualMachine) -> PyIntRef { if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { zelf.to_owned() } else { @@ -788,47 +821,6 @@ impl PyInt { vm.ctx.new_int(zelf.value.clone()) } } - - const NUMBER_METHODS: PyNumberMethods = PyNumberMethods { - add: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a + b, vm)), - subtract: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a - b, vm)), - multiply: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a * b, vm)), - remainder: Some(|number, other, vm| Self::np_general_op(number, other, inner_mod, vm)), - divmod: Some(|number, other, vm| Self::np_general_op(number, other, inner_divmod, vm)), - power: Some(|number, other, vm| Self::np_general_op(number, other, inner_pow, vm)), - negative: Some(|number, vm| { - Self::number_downcast(number) - .value - .clone() - .neg() - .to_pyresult(vm) - }), - positive: Some(|number, vm| Ok(Self::np_int(number, vm).into())), - absolute: Some(|number, vm| Self::number_downcast(number).value.abs().to_pyresult(vm)), - boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), - invert: Some(|number, vm| { - let value = Self::number_downcast(number).value.clone(); - (!value).to_pyresult(vm) - }), - lshift: Some(|number, other, vm| Self::np_general_op(number, other, inner_lshift, vm)), - rshift: Some(|number, other, vm| Self::np_general_op(number, other, inner_rshift, vm)), - and: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a & b, vm)), - xor: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a ^ b, vm)), - or: Some(|number, other, vm| Self::np_int_op(number, other, |a, b| a | b, vm)), - int: Some(|number, other| Ok(Self::np_int(number, other))), - float: Some(|number, vm| { - let zelf = Self::number_downcast(number); - try_to_float(&zelf.value, vm).map(|x| vm.ctx.new_float(x)) - }), - floor_divide: Some(|number, other, vm| { - Self::np_general_op(number, other, inner_floordiv, vm) - }), - true_divide: Some(|number, other, vm| { - Self::np_general_op(number, other, inner_truediv, vm) - }), - index: Some(|number, vm| Ok(Self::np_int(number, vm))), - ..*PyNumberMethods::not_implemented() - }; } #[derive(FromArgs)] diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 2868ca8bd3..046ba8faed 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -34,7 +34,7 @@ pub struct PyNumberMethods { pub float: Option PyResult>>, pub inplace_add: Option PyResult>, - pub inplace_substract: Option PyResult>, + pub inplace_subtract: Option PyResult>, pub inplace_multiply: Option PyResult>, pub inplace_remainder: Option PyResult>, pub inplace_divmod: Option PyResult>, @@ -48,7 +48,7 @@ pub struct PyNumberMethods { pub floor_divide: Option PyResult>, pub true_divide: Option PyResult>, pub inplace_floor_divide: Option PyResult>, - pub inplace_true_devide: Option PyResult>, + pub inplace_true_divide: Option PyResult>, pub index: Option PyResult>, @@ -57,9 +57,44 @@ pub struct PyNumberMethods { } impl PyNumberMethods { - pub const fn not_implemented() -> &'static Self { - &NOT_IMPLEMENTED - } + pub const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods { + add: None, + subtract: None, + multiply: None, + remainder: None, + divmod: None, + power: None, + negative: None, + positive: None, + absolute: None, + boolean: None, + invert: None, + lshift: None, + rshift: None, + and: None, + xor: None, + or: None, + int: None, + float: None, + inplace_add: None, + inplace_subtract: None, + inplace_multiply: None, + inplace_remainder: None, + inplace_divmod: None, + inplace_power: None, + inplace_lshift: None, + inplace_rshift: None, + inplace_and: None, + inplace_xor: None, + inplace_or: None, + floor_divide: None, + true_divide: None, + inplace_floor_divide: None, + inplace_true_divide: None, + index: None, + matrix_multiply: None, + inplace_matrix_multiply: None, + }; } pub struct PyNumber<'a> { @@ -88,7 +123,7 @@ impl PyNumber<'_> { .class() .mro_find_map(|x| x.slots.as_number.load()) .map(|f| f(self.obj, vm)) - .unwrap_or_else(|| Cow::Borrowed(PyNumberMethods::not_implemented())) + .unwrap_or_else(|| Cow::Borrowed(&PyNumberMethods::NOT_IMPLEMENTED)) }) } @@ -119,8 +154,8 @@ impl PyNumber<'_> { } } - if self.obj.class().is(PyInt::class(vm)) { - Ok(unsafe { self.obj.to_owned().downcast_unchecked::() }) + if let Some(i) = self.obj.downcast_ref_if_exact::(vm) { + Ok(i.to_owned()) } else if let Some(f) = self.methods(vm).int { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { @@ -135,7 +170,7 @@ impl PyNumber<'_> { 1, vm, )?; - Ok(vm.ctx.new_int(ret.as_bigint().clone())) + Ok(vm.ctx.new_bigint(ret.as_bigint())) } else { Ok(ret) } @@ -176,12 +211,10 @@ impl PyNumber<'_> { } pub fn index_opt(&self, vm: &VirtualMachine) -> PyResult> { - if self.obj.class().is(PyInt::class(vm)) { - Ok(Some(unsafe { - self.obj.to_owned().downcast_unchecked::() - })) - } else if let Some(i) = self.obj.downcast_ref::() { + if let Some(i) = self.obj.downcast_ref_if_exact::(vm) { Ok(Some(i.to_owned())) + } else if let Some(i) = self.obj.payload::() { + Ok(Some(vm.ctx.new_bigint(i.as_bigint()))) } else if let Some(f) = self.methods(vm).index { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { @@ -196,7 +229,7 @@ impl PyNumber<'_> { 1, vm, )?; - Ok(Some(vm.ctx.new_int(ret.as_bigint().clone()))) + Ok(Some(vm.ctx.new_bigint(ret.as_bigint()))) } else { Ok(Some(ret)) } @@ -215,10 +248,8 @@ impl PyNumber<'_> { } pub fn float_opt(&self, vm: &VirtualMachine) -> PyResult>> { - if self.obj.class().is(PyFloat::class(vm)) { - Ok(Some(unsafe { - self.obj.to_owned().downcast_unchecked::() - })) + if let Some(float) = self.obj.downcast_ref_if_exact::(vm) { + Ok(Some(float.to_owned())) } else if let Some(f) = self.methods(vm).float { let ret = f(self, vm)?; if !ret.class().is(PyFloat::class(vm)) { @@ -254,42 +285,3 @@ impl PyNumber<'_> { }) } } - -const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods { - add: None, - subtract: None, - multiply: None, - remainder: None, - divmod: None, - power: None, - negative: None, - positive: None, - absolute: None, - boolean: None, - invert: None, - lshift: None, - rshift: None, - and: None, - xor: None, - or: None, - int: None, - float: None, - inplace_add: None, - inplace_substract: None, - inplace_multiply: None, - inplace_remainder: None, - inplace_divmod: None, - inplace_power: None, - inplace_lshift: None, - inplace_rshift: None, - inplace_and: None, - inplace_xor: None, - inplace_or: None, - floor_divide: None, - true_divide: None, - inplace_floor_divide: None, - inplace_true_devide: None, - index: None, - matrix_multiply: None, - inplace_matrix_multiply: None, -}; diff --git a/vm/src/stdlib/warnings.rs b/vm/src/stdlib/warnings.rs index 4bf14ef228..7eca878a46 100644 --- a/vm/src/stdlib/warnings.rs +++ b/vm/src/stdlib/warnings.rs @@ -8,10 +8,7 @@ pub fn warn( stack_level: usize, vm: &VirtualMachine, ) -> PyResult<()> { - // let module = vm.import("warnings", None, 0)?; - // let func = module.get_attr("warn", vm)?; - // vm.invoke(&func, (message, category, stack_level))?; - // TODO + // TODO: use rust warnings module if let Ok(module) = vm.import("warnings", None, 0) { if let Ok(func) = module.get_attr("warn", vm) { let _ = vm.invoke(&func, (message, category.to_owned(), stack_level)); diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 0ca4343984..38ad646548 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -375,7 +375,7 @@ fn as_number_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PyNum }) } ), - ..*PyNumberMethods::not_implemented() + ..PyNumberMethods::NOT_IMPLEMENTED }) } @@ -1037,15 +1037,14 @@ pub trait AsSequence: PyPayload { #[pyimpl] pub trait AsNumber: PyPayload { + const AS_NUMBER: PyNumberMethods; + #[inline] #[pyslot] - fn slot_as_number(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { - let zelf = unsafe { zelf.downcast_unchecked_ref::() }; - Self::as_number(zelf, vm) + fn as_number(_zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { + Cow::Borrowed(&Self::AS_NUMBER) } - fn as_number(zelf: &Py, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods>; - fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py { unsafe { number.obj.downcast_unchecked_ref() } } From 09fc67616436b62110c13d689bd005ef218751ec Mon Sep 17 00:00:00 2001 From: Jeong Yunwon Date: Mon, 30 May 2022 06:54:56 +0900 Subject: [PATCH 7/7] Use static ref --- vm/src/builtins/float.rs | 4 +- vm/src/builtins/int.rs | 2 +- vm/src/function/number.rs | 2 +- vm/src/protocol/number.rs | 56 ++++++++++----------- vm/src/types/slot.rs | 102 +++++++++++++++++++++++--------------- vm/src/vm/vm_ops.rs | 6 ++- 6 files changed, 95 insertions(+), 77 deletions(-) diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index cfa2e7e250..d83439b5eb 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -59,11 +59,11 @@ impl From for PyFloat { impl PyObject { pub fn try_float_opt(&self, vm: &VirtualMachine) -> PyResult>> { - PyNumber::from(self).float_opt(vm) + PyNumber::new(self, vm).float_opt(vm) } pub fn try_float(&self, vm: &VirtualMachine) -> PyResult> { - PyNumber::from(self).float(vm) + PyNumber::new(self, vm).float(vm) } } diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index 37405831c7..9a1ad9cfab 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -265,7 +265,7 @@ impl Constructor for PyInt { val }; - PyNumber::from(val.as_ref()) + PyNumber::new(val.as_ref(), vm) .int(vm) .map(|x| x.as_bigint().clone()) } diff --git a/vm/src/function/number.rs b/vm/src/function/number.rs index 62922c85ba..4e8219c342 100644 --- a/vm/src/function/number.rs +++ b/vm/src/function/number.rs @@ -82,7 +82,7 @@ impl Deref for ArgIntoFloat { impl TryFromObject for ArgIntoFloat { // Equivalent to PyFloat_AsDouble. fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let value = PyNumber::from(obj.as_ref()).float(vm)?.to_f64(); + let value = PyNumber::new(obj.as_ref(), vm).float(vm)?.to_f64(); Ok(ArgIntoFloat { value }) } } diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 046ba8faed..f2420f709b 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -1,8 +1,5 @@ -use std::borrow::Cow; - use crate::{ builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, - common::lock::OnceCell, function::ArgBytesLike, stdlib::warnings, AsObject, PyObject, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, @@ -100,45 +97,42 @@ impl PyNumberMethods { pub struct PyNumber<'a> { pub obj: &'a PyObject, // some fast path do not need methods, so we do lazy initialize - methods: OnceCell>, + pub methods: Option<&'static PyNumberMethods>, } -impl<'a> From<&'a PyObject> for PyNumber<'a> { - fn from(obj: &'a PyObject) -> Self { +impl<'a> PyNumber<'a> { + pub fn new(obj: &'a PyObject, vm: &VirtualMachine) -> Self { Self { obj, - methods: OnceCell::new(), + methods: Self::find_methods(obj, vm), } } } impl PyNumber<'_> { - pub fn methods(&self, vm: &VirtualMachine) -> &PyNumberMethods { - &*self.methods_cow(vm) + pub fn find_methods(obj: &PyObject, vm: &VirtualMachine) -> Option<&'static PyNumberMethods> { + obj.class() + .mro_find_map(|x| x.slots.as_number.load()) + .map(|f| f(obj, vm)) } - pub fn methods_cow(&self, vm: &VirtualMachine) -> &Cow<'static, PyNumberMethods> { - self.methods.get_or_init(|| { - self.obj - .class() - .mro_find_map(|x| x.slots.as_number.load()) - .map(|f| f(self.obj, vm)) - .unwrap_or_else(|| Cow::Borrowed(&PyNumberMethods::NOT_IMPLEMENTED)) - }) + pub fn methods(&self) -> &'static PyNumberMethods { + self.methods.unwrap_or(&PyNumberMethods::NOT_IMPLEMENTED) } // PyNumber_Check - pub fn check(&self, vm: &VirtualMachine) -> bool { - let methods = self.methods(vm); - methods.int.is_some() - || methods.index.is_some() - || methods.float.is_some() - || self.obj.payload_is::() + pub fn check(obj: &PyObject, vm: &VirtualMachine) -> bool { + Self::find_methods(obj, vm).map_or(false, |methods| { + methods.int.is_some() + || methods.index.is_some() + || methods.float.is_some() + || obj.payload_is::() + }) } // PyIndex_Check - pub fn is_index(&self, vm: &VirtualMachine) -> bool { - self.methods(vm).index.is_some() + pub fn is_index(&self) -> bool { + self.methods().index.is_some() } pub fn int(&self, vm: &VirtualMachine) -> PyResult { @@ -156,7 +150,7 @@ impl PyNumber<'_> { if let Some(i) = self.obj.downcast_ref_if_exact::(vm) { Ok(i.to_owned()) - } else if let Some(f) = self.methods(vm).int { + } else if let Some(f) = self.methods().int { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { warnings::warn( @@ -174,7 +168,7 @@ impl PyNumber<'_> { } else { Ok(ret) } - } else if self.methods(vm).index.is_some() { + } else if self.methods().index.is_some() { self.index(vm) } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), identifier!(vm, __trunc__)) @@ -187,7 +181,7 @@ impl PyNumber<'_> { // vm, // )?; let ret = f.invoke((), vm)?; - PyNumber::from(ret.as_ref()).index(vm).map_err(|_| { + PyNumber::new(ret.as_ref(), vm).index(vm).map_err(|_| { vm.new_type_error(format!( "__trunc__ returned non-Integral (type {})", ret.class() @@ -215,7 +209,7 @@ impl PyNumber<'_> { Ok(Some(i.to_owned())) } else if let Some(i) = self.obj.payload::() { Ok(Some(vm.ctx.new_bigint(i.as_bigint()))) - } else if let Some(f) = self.methods(vm).index { + } else if let Some(f) = self.methods().index { let ret = f(self, vm)?; if !ret.class().is(PyInt::class(vm)) { warnings::warn( @@ -250,7 +244,7 @@ impl PyNumber<'_> { pub fn float_opt(&self, vm: &VirtualMachine) -> PyResult>> { if let Some(float) = self.obj.downcast_ref_if_exact::(vm) { Ok(Some(float.to_owned())) - } else if let Some(f) = self.methods(vm).float { + } else if let Some(f) = self.methods().float { let ret = f(self, vm)?; if !ret.class().is(PyFloat::class(vm)) { warnings::warn( @@ -268,7 +262,7 @@ impl PyNumber<'_> { } else { Ok(Some(ret)) } - } else if self.methods(vm).index.is_some() { + } else if self.methods().index.is_some() { let i = self.index(vm)?; let value = int::try_to_float(i.as_bigint(), vm)?; Ok(Some(vm.ctx.new_float(value))) diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 38ad646548..ae3099a284 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -139,7 +139,7 @@ impl Default for PyTypeFlags { pub(crate) type GenericMethod = fn(&PyObject, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type AsMappingFunc = fn(&PyObject, &VirtualMachine) -> &'static PyMappingMethods; -pub(crate) type AsNumberFunc = fn(&PyObject, &VirtualMachine) -> Cow<'static, PyNumberMethods>; +pub(crate) type AsNumberFunc = fn(&PyObject, &VirtualMachine) -> &'static PyNumberMethods; pub(crate) type HashFunc = fn(&PyObject, &VirtualMachine) -> PyResult; // CallFunc = GenericMethod pub(crate) type GetattroFunc = fn(&PyObject, PyStrRef, &VirtualMachine) -> PyResult; @@ -340,43 +340,65 @@ fn as_sequence_generic(zelf: &PyObject, vm: &VirtualMachine) -> &'static PySeque static_as_sequence_generic(has_length, has_ass_item) } -fn as_number_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { - Cow::Owned(PyNumberMethods { - int: then_some_closure!( - zelf.class().has_attr(identifier!(vm, __int__)), - |num, vm| { - let ret = - vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) - }) - } - ), - float: then_some_closure!( - zelf.class().has_attr(identifier!(vm, __float__)), - |num, vm| { - let ret = - vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!( - "__float__ returned non-float (type {})", - obj.class() - )) - }) - } - ), - index: then_some_closure!( - zelf.class().has_attr(identifier!(vm, __index__)), - |num, vm| { - let ret = - vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) - }) - } - ), - ..PyNumberMethods::NOT_IMPLEMENTED - }) +pub(crate) fn static_as_number_generic( + has_int: bool, + has_float: bool, + has_index: bool, +) -> &'static PyNumberMethods { + static METHODS: &[PyNumberMethods] = &[ + new_generic(false, false, false), + new_generic(true, false, false), + new_generic(false, true, false), + new_generic(true, true, false), + new_generic(false, false, true), + new_generic(true, false, true), + new_generic(false, true, true), + new_generic(true, true, true), + ]; + + fn int(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) + }) + } + fn float(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "__float__ returned non-float (type {})", + obj.class() + )) + }) + } + fn index(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) + }) + } + + const fn new_generic(has_int: bool, has_float: bool, has_index: bool) -> PyNumberMethods { + PyNumberMethods { + int: if has_int { Some(int) } else { None }, + float: if has_float { Some(float) } else { None }, + index: if has_index { Some(index) } else { None }, + ..PyNumberMethods::NOT_IMPLEMENTED + } + } + + let key = bool_int(has_int) | (bool_int(has_float) << 1) | (bool_int(has_index) << 2); + + &METHODS[key] +} + +fn as_number_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> &'static PyNumberMethods { + let (has_int, has_float, has_index) = ( + zelf.class().has_attr(identifier!(vm, __int__)), + zelf.class().has_attr(identifier!(vm, __float__)), + zelf.class().has_attr(identifier!(vm, __index__)), + ); + static_as_number_generic(has_int, has_float, has_index) } fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { @@ -1041,8 +1063,8 @@ pub trait AsNumber: PyPayload { #[inline] #[pyslot] - fn as_number(_zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PyNumberMethods> { - Cow::Borrowed(&Self::AS_NUMBER) + fn as_number(_zelf: &PyObject, _vm: &VirtualMachine) -> &'static PyNumberMethods { + &Self::AS_NUMBER } fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py { diff --git a/vm/src/vm/vm_ops.rs b/vm/src/vm/vm_ops.rs index 1c325595cd..f8af60b7a0 100644 --- a/vm/src/vm/vm_ops.rs +++ b/vm/src/vm/vm_ops.rs @@ -10,11 +10,13 @@ use crate::{ /// Collection of operators impl VirtualMachine { pub fn to_index_opt(&self, obj: PyObjectRef) -> Option> { - PyNumber::from(obj.as_ref()).index_opt(self).transpose() + PyNumber::new(obj.as_ref(), self) + .index_opt(self) + .transpose() } pub fn to_index(&self, obj: &PyObject) -> PyResult { - PyNumber::from(obj).index(self) + PyNumber::new(obj, self).index(self) } #[inline]