Skip to content

Refactor Number Protocol BinaryOps & HeapTypeExt #4720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
7 changes: 6 additions & 1 deletion derive-impl/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,8 @@ where
let slot_name = slot_ident.to_string();
let tokens = {
const NON_ATOMIC_SLOTS: &[&str] = &["as_buffer"];
const POINTER_SLOTS: &[&str] = &["as_number", "as_sequence", "as_mapping"];
const POINTER_SLOTS: &[&str] = &["as_sequence", "as_mapping"];
const STATIC_SLOTS: &[&str] = &["as_number"];
if NON_ATOMIC_SLOTS.contains(&slot_name.as_str()) {
quote_spanned! { span =>
slots.#slot_ident = Some(Self::#ident as _);
Expand All @@ -711,6 +712,10 @@ where
quote_spanned! { span =>
slots.#slot_ident.store(Some(PointerSlot::from(Self::#ident())));
}
} else if STATIC_SLOTS.contains(&slot_name.as_str()) {
quote_spanned! { span =>
slots.#slot_ident = Self::#ident();
}
} else {
quote_spanned! { span =>
slots.#slot_ident.store(Some(Self::#ident as _));
Expand Down
3 changes: 1 addition & 2 deletions stdlib/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let array = module
.get_attr("array", vm)
.expect("Expect array has array type.");
array.init_builtin_number_slots(&vm.ctx);

let collections_abc = vm
.import("collections.abc", None, 0)
Expand Down Expand Up @@ -722,7 +721,7 @@ mod array {

#[pyclass(
flags(BASETYPE),
with(Comparable, AsBuffer, AsMapping, Iterable, Constructor)
with(Comparable, AsBuffer, AsMapping, AsSequence, Iterable, Constructor)
)]
impl PyArray {
fn read(&self) -> PyRwLockReadGuard<'_, ArrayContentType> {
Expand Down
12 changes: 3 additions & 9 deletions vm/src/builtins/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,9 @@ impl PyBool {
impl AsNumber for PyBool {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
and: Some(|number, other, vm| {
PyBool::and(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
}),
xor: Some(|number, other, vm| {
PyBool::xor(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
}),
or: Some(|number, other, vm| {
PyBool::or(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
}),
and: Some(|a, b, vm| PyBool::and(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
xor: Some(|a, b, vm| PyBool::xor(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
or: Some(|a, b, vm| PyBool::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
..PyInt::AS_NUMBER
};
&AS_NUMBER
Expand Down
6 changes: 3 additions & 3 deletions vm/src/builtins/bytearray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,9 @@ impl AsSequence for PyByteArray {
impl AsNumber for PyByteArray {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
remainder: Some(|number, other, vm| {
if let Some(number) = number.obj.downcast_ref::<PyByteArray>() {
number.mod_(other.to_owned(), vm).to_pyresult(vm)
remainder: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PyByteArray>() {
a.mod_(b.to_owned(), vm).to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
Expand Down
6 changes: 3 additions & 3 deletions vm/src/builtins/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ impl AsSequence for PyBytes {
impl AsNumber for PyBytes {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
remainder: Some(|number, other, vm| {
if let Some(number) = number.obj.downcast_ref::<PyBytes>() {
number.mod_(other.to_owned(), vm).to_pyresult(vm)
remainder: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PyBytes>() {
a.mod_(b.to_owned(), vm).to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
Expand Down
20 changes: 7 additions & 13 deletions vm/src/builtins/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
PyComparisonValue,
},
identifier,
protocol::{PyNumber, PyNumberMethods},
protocol::PyNumberMethods,
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
};
Expand Down Expand Up @@ -418,16 +418,10 @@ impl Hashable for PyComplex {
impl AsNumber for PyComplex {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
add: Some(|number, other, vm| {
PyComplex::number_op(number, other, |a, b, _vm| a + b, vm)
}),
subtract: Some(|number, other, vm| {
PyComplex::number_op(number, other, |a, b, _vm| a - b, vm)
}),
multiply: Some(|number, other, vm| {
PyComplex::number_op(number, other, |a, b, _vm| a * b, vm)
}),
power: Some(|number, other, vm| PyComplex::number_op(number, other, inner_pow, vm)),
add: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a + b, vm)),
subtract: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a - b, vm)),
multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
power: Some(|a, b, vm| PyComplex::number_op(a, b, inner_pow, vm)),
negative: Some(|number, vm| {
let value = PyComplex::number_downcast(number).value;
(-value).to_pyresult(vm)
Expand Down Expand Up @@ -494,12 +488,12 @@ impl Representable for PyComplex {
}

impl PyComplex {
fn number_op<F, R>(number: PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
where
F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
R: ToPyResult,
{
if let (Some(a), Some(b)) = (to_op_complex(number.obj, vm)?, to_op_complex(other, vm)?) {
if let (Some(a), Some(b)) = (to_op_complex(a, vm)?, to_op_complex(b, vm)?) {
op(a, b, vm).to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
Expand Down
138 changes: 42 additions & 96 deletions vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,16 +472,16 @@ impl AsSequence for PyDict {
impl AsNumber for PyDict {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
or: Some(|num, args, vm| {
if let Some(num) = num.obj.downcast_ref::<PyDict>() {
PyDict::or(num, args.to_pyobject(vm), vm)
or: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PyDict>() {
PyDict::or(a, b.to_pyobject(vm), vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
inplace_or: Some(|num, args, vm| {
if let Some(num) = num.obj.downcast_ref::<PyDict>() {
PyDict::ior(num.to_owned(), args.to_pyobject(vm), vm).map(|d| d.into())
inplace_or: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PyDict>() {
PyDict::ior(a.to_owned(), b.to_pyobject(vm), vm).map(|d| d.into())
} else {
Ok(vm.ctx.not_implemented())
}
Expand Down Expand Up @@ -1169,51 +1169,10 @@ impl AsSequence for PyDictKeys {
impl AsNumber for PyDictKeys {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
subtract: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num
.difference(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
and: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num
.intersection(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
xor: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num.symmetric_difference(
ArgIterable::try_from_object(vm, args.to_owned())?,
vm,
)?,
}
.into_pyobject(vm))
}),
or: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num.union(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
subtract: Some(inner_set_number_subtract),
and: Some(inner_set_number_and),
xor: Some(inner_set_number_xor),
or: Some(inner_set_number_or),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
Expand Down Expand Up @@ -1288,51 +1247,10 @@ impl AsSequence for PyDictItems {
impl AsNumber for PyDictItems {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
subtract: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num
.difference(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
and: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num
.intersection(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
xor: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num.symmetric_difference(
ArgIterable::try_from_object(vm, args.to_owned())?,
vm,
)?,
}
.into_pyobject(vm))
}),
or: Some(|num, args, vm| {
let num = PySetInner::from_iter(
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
vm,
)?;
Ok(PySet {
inner: num.union(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
}
.into_pyobject(vm))
}),
subtract: Some(inner_set_number_subtract),
and: Some(inner_set_number_and),
xor: Some(inner_set_number_xor),
or: Some(inner_set_number_or),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
Expand All @@ -1358,6 +1276,34 @@ impl AsSequence for PyDictValues {
}
}

fn inner_set_number_op<F>(a: &PyObject, b: &PyObject, f: F, vm: &VirtualMachine) -> PyResult
where
F: FnOnce(PySetInner, ArgIterable) -> PyResult<PySetInner>,
{
let a = PySetInner::from_iter(
ArgIterable::try_from_object(vm, a.to_owned())?.iter(vm)?,
vm,
)?;
let b = ArgIterable::try_from_object(vm, b.to_owned())?;
Ok(PySet { inner: f(a, b)? }.into_pyobject(vm))
}

fn inner_set_number_subtract(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
inner_set_number_op(a, b, |a, b| a.difference(b, vm), vm)
}

fn inner_set_number_and(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
inner_set_number_op(a, b, |a, b| a.intersection(b, vm), vm)
}

fn inner_set_number_xor(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
inner_set_number_op(a, b, |a, b| a.symmetric_difference(b, vm), vm)
}

fn inner_set_number_or(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
inner_set_number_op(a, b, |a, b| a.union(b, vm), vm)
}

pub(crate) fn init(context: &Context) {
PyDict::extend_class(context, context.types.dict_type);
PyDictKeys::extend_class(context, context.types.dict_keys_type);
Expand Down
40 changes: 20 additions & 20 deletions vm/src/builtins/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
PyArithmeticValue::{self, *},
PyComparisonValue,
},
protocol::{PyNumber, PyNumberMethods},
protocol::PyNumberMethods,
types::{AsNumber, Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
TryFromBorrowedObject, TryFromObject, VirtualMachine,
Expand Down Expand Up @@ -544,29 +544,29 @@ impl Hashable for PyFloat {
impl AsNumber for PyFloat {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
add: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a + b, vm)),
subtract: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a - b, vm)),
multiply: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a * b, vm)),
remainder: Some(|num, other, vm| PyFloat::number_op(num, other, inner_mod, vm)),
divmod: Some(|num, other, vm| PyFloat::number_op(num, other, inner_divmod, vm)),
power: Some(|num, other, vm| PyFloat::number_op(num, other, float_pow, vm)),
negative: Some(|num, vm| {
let value = PyFloat::number_downcast(num).value;
add: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a + b, vm)),
subtract: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a - b, vm)),
multiply: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a * b, vm)),
remainder: Some(|a, b, vm| PyFloat::number_op(a, b, inner_mod, vm)),
divmod: Some(|a, b, vm| PyFloat::number_op(a, b, inner_divmod, vm)),
power: Some(|a, b, vm| PyFloat::number_op(a, b, float_pow, vm)),
negative: Some(|a, vm| {
let value = PyFloat::number_downcast(a).value;
(-value).to_pyresult(vm)
}),
positive: Some(|num, vm| PyFloat::number_downcast_exact(num, vm).to_pyresult(vm)),
absolute: Some(|num, vm| {
let value = PyFloat::number_downcast(num).value;
positive: Some(|a, vm| PyFloat::number_downcast_exact(a, vm).to_pyresult(vm)),
absolute: Some(|a, vm| {
let value = PyFloat::number_downcast(a).value;
value.abs().to_pyresult(vm)
}),
boolean: Some(|num, _vm| Ok(PyFloat::number_downcast(num).value.is_zero())),
int: Some(|num, vm| {
let value = PyFloat::number_downcast(num).value;
boolean: Some(|a, _vm| Ok(PyFloat::number_downcast(a).value.is_zero())),
int: Some(|a, vm| {
let value = PyFloat::number_downcast(a).value;
try_to_bigint(value, vm).map(|x| vm.ctx.new_int(x))
}),
float: Some(|num, vm| Ok(PyFloat::number_downcast_exact(num, vm))),
floor_divide: Some(|num, other, vm| PyFloat::number_op(num, other, inner_floordiv, vm)),
true_divide: Some(|num, other, vm| PyFloat::number_op(num, other, inner_div, vm)),
float: Some(|a, vm| Ok(PyFloat::number_downcast_exact(a, vm))),
floor_divide: Some(|a, b, vm| PyFloat::number_op(a, b, inner_floordiv, vm)),
true_divide: Some(|a, b, vm| PyFloat::number_op(a, b, inner_div, vm)),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
Expand All @@ -586,12 +586,12 @@ impl Representable for PyFloat {
}

impl PyFloat {
fn number_op<F, R>(number: PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
where
F: FnOnce(f64, f64, &VirtualMachine) -> R,
R: ToPyResult,
{
if let (Some(a), Some(b)) = (to_op_float(number.obj, vm)?, to_op_float(other, vm)?) {
if let (Some(a), Some(b)) = (to_op_float(a, vm)?, to_op_float(b, vm)?) {
op(a, b, vm).to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
Expand Down
Loading