Skip to content

Improve: binaryops with number protocol #4139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vm/src/builtins/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ impl PyBytes {
}

#[pymethod(magic)]
fn rmod(&self, _values: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
vm.ctx.not_implemented()
fn rmod(&self, _values: PyObjectRef, vm: &VirtualMachine) -> PyResult {
Ok(vm.ctx.not_implemented())
}

/// Return a string decoded from the given bytes.
Expand Down
2 changes: 1 addition & 1 deletion vm/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ mod sequence;
pub use buffer::{BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, VecBuffer};
pub use iter::{PyIter, PyIterIter, PyIterReturn};
pub use mapping::{PyMapping, PyMappingMethods};
pub use number::{PyNumber, PyNumberMethods};
pub use number::{PyNumber, PyNumberMethods, PyNumberMethodsOffset};
pub use sequence::{PySequence, PySequenceMethods};
110 changes: 104 additions & 6 deletions vm/src/protocol/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
VirtualMachine,
};
use crossbeam_utils::atomic::AtomicCell;
use std::ptr;

type UnaryFunc<R = PyObjectRef> = AtomicCell<Option<fn(PyNumber, &VirtualMachine) -> PyResult<R>>>;
type BinaryFunc<R = PyObjectRef> =
Expand Down Expand Up @@ -109,6 +110,7 @@ impl PyObject {
}

#[derive(Default)]
// #[repr(C)]
pub struct PyNumberMethods {
/* Number implementations must check *both*
arguments for proper type and implement the necessary conversions
Expand Down Expand Up @@ -199,6 +201,98 @@ impl PyNumberMethods {
};
}

pub enum PyNumberMethodsOffset {
Add,
Subtract,
Multiply,
Remainder,
Divmod,
Power,
Negative,
Positive,
Absolute,
Boolean,
Invert,
Lshift,
Rshift,
And,
Xor,
Or,
Int,
Float,
InplaceAdd,
InplaceSubtract,
InplaceMultiply,
InplaceRemainder,
InplaceDivmod,
InplacePower,
InplaceLshift,
InplaceRshift,
InplaceAnd,
InplaceXor,
InplaceOr,
FloorDivide,
TrueDivide,
InplaceFloorDivide,
InplaceTrueDivide,
Index,
MatrixMultiply,
InplaceMatrixMultiply,
}

impl PyNumberMethodsOffset {
pub fn method(&self, methods: &PyNumberMethods, vm: &VirtualMachine) -> PyResult<&BinaryFunc> {
use PyNumberMethodsOffset::*;
unsafe {
match self {
// BinaryFunc
Add => ptr::addr_of!(methods.add),
Subtract => ptr::addr_of!(methods.subtract),
Multiply => ptr::addr_of!(methods.multiply),
Remainder => ptr::addr_of!(methods.remainder),
Divmod => ptr::addr_of!(methods.divmod),
Power => ptr::addr_of!(methods.power),
Lshift => ptr::addr_of!(methods.lshift),
Rshift => ptr::addr_of!(methods.rshift),
And => ptr::addr_of!(methods.and),
Xor => ptr::addr_of!(methods.xor),
Or => ptr::addr_of!(methods.or),
InplaceAdd => ptr::addr_of!(methods.inplace_add),
InplaceSubtract => ptr::addr_of!(methods.inplace_subtract),
InplaceMultiply => ptr::addr_of!(methods.inplace_multiply),
InplaceRemainder => ptr::addr_of!(methods.inplace_remainder),
InplaceDivmod => ptr::addr_of!(methods.inplace_divmod),
InplacePower => ptr::addr_of!(methods.inplace_power),
InplaceLshift => ptr::addr_of!(methods.inplace_lshift),
InplaceRshift => ptr::addr_of!(methods.inplace_rshift),
InplaceAnd => ptr::addr_of!(methods.inplace_and),
InplaceXor => ptr::addr_of!(methods.inplace_xor),
InplaceOr => ptr::addr_of!(methods.inplace_or),
FloorDivide => ptr::addr_of!(methods.floor_divide),
TrueDivide => ptr::addr_of!(methods.true_divide),
InplaceFloorDivide => ptr::addr_of!(methods.inplace_floor_divide),
InplaceTrueDivide => ptr::addr_of!(methods.inplace_true_divide),
MatrixMultiply => ptr::addr_of!(methods.matrix_multiply),
InplaceMatrixMultiply => ptr::addr_of!(methods.inplace_matrix_multiply),
// UnaryFunc
Negative => ptr::null(),
Positive => ptr::null(),
Absolute => ptr::null(),
Boolean => ptr::null(),
Invert => ptr::null(),
Int => ptr::null(),
Float => ptr::null(),
Index => ptr::null(),
}
.as_ref()
.ok_or_else(|| {
vm.new_value_error("No unaryop supported for PyNumberMethodsOffset".to_owned())
})
}
}
}

#[derive(Copy, Clone)]
pub struct PyNumber<'a> {
pub obj: &'a PyObject,
methods: &'a PyNumberMethods,
Expand All @@ -220,8 +314,12 @@ impl PyNumber<'_> {
obj.class().mro_find_map(|x| x.slots.as_number.load())
}

pub fn methods(&self) -> &PyNumberMethods {
self.methods
pub fn methods<'a>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pub fn methods<'a>(
pub fn binary_op<'a>(

&'a self,
op_slot: &'a PyNumberMethodsOffset,
vm: &VirtualMachine,
) -> PyResult<&BinaryFunc> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we keep the name binary_op, the return type will be something other than BinaryFunc

op_slot.method(self.methods, vm)
}

// PyNumber_Check
Expand All @@ -238,12 +336,12 @@ impl PyNumber<'_> {

// PyIndex_Check
pub fn is_index(&self) -> bool {
self.methods().index.load().is_some()
self.methods.index.load().is_some()
}

#[inline]
pub fn int(self, vm: &VirtualMachine) -> Option<PyResult<PyIntRef>> {
self.methods().int.load().map(|f| {
self.methods.int.load().map(|f| {
let ret = f(self, vm)?;
let value = if !ret.class().is(PyInt::class(vm)) {
warnings::warn(
Expand All @@ -267,7 +365,7 @@ impl PyNumber<'_> {

#[inline]
pub fn index(self, vm: &VirtualMachine) -> Option<PyResult<PyIntRef>> {
self.methods().index.load().map(|f| {
self.methods.index.load().map(|f| {
let ret = f(self, vm)?;
let value = if !ret.class().is(PyInt::class(vm)) {
warnings::warn(
Expand All @@ -291,7 +389,7 @@ impl PyNumber<'_> {

#[inline]
pub fn float(self, vm: &VirtualMachine) -> Option<PyResult<PyRef<PyFloat>>> {
self.methods().float.load().map(|f| {
self.methods.float.load().map(|f| {
let ret = f(self, vm)?;
let value = if !ret.class().is(PyFloat::class(vm)) {
warnings::warn(
Expand Down
12 changes: 4 additions & 8 deletions vm/src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod builtins {
ArgBytesLike, ArgCallable, ArgIntoBool, ArgIterable, ArgMapping, ArgStrOrBytesLike,
Either, FuncArgs, KwArgs, OptionalArg, OptionalOption, PosArgs, PyArithmeticValue,
},
protocol::{PyIter, PyIterReturn},
protocol::{PyIter, PyIterReturn, PyNumberMethodsOffset},
py_io,
readline::{Readline, ReadlineResult},
stdlib::sys,
Expand Down Expand Up @@ -610,13 +610,9 @@ mod builtins {
modulus,
} = args;
match modulus {
None => vm.call_or_reflection(
&x,
&y,
identifier!(vm, __pow__),
identifier!(vm, __rpow__),
|vm, x, y| Err(vm.new_unsupported_binop_error(x, y, "pow")),
),
None => vm.binary_op(&x, &y, PyNumberMethodsOffset::Power, "pow", |vm, _, _| {
Ok(vm.ctx.not_implemented())
}),
Some(z) => {
let try_pow_value = |obj: &PyObject,
args: (PyObjectRef, PyObjectRef, PyObjectRef)|
Expand Down
Loading