Skip to content

Commit ea665cb

Browse files
authored
Merge pull request #4615 from xiaozhiyan/binaryops-with-number-protocol
Improve: binary ops with Number Protocol
2 parents 1fceeab + 2287720 commit ea665cb

29 files changed

+1255
-737
lines changed

Lib/test/test_collections.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ def __contains__(self, key):
259259
d = c.new_child(b=20, c=30)
260260
self.assertEqual(d.maps, [{'b': 20, 'c': 30}, {'a': 1, 'b': 2}])
261261

262+
# TODO: RUSTPYTHON
263+
@unittest.expectedFailure
262264
def test_union_operators(self):
263265
cm1 = ChainMap(dict(a=1, b=2), dict(c=3, d=4))
264266
cm2 = ChainMap(dict(a=10, e=5), dict(b=20, d=4))

Lib/test/test_enum.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,9 +2466,7 @@ class Color(StrMixin, AllMixin, Flag):
24662466
self.assertEqual(Color.ALL.value, 7)
24672467
self.assertEqual(str(Color.BLUE), 'blue')
24682468

2469-
# TODO: RUSTPYTHON
2470-
@unittest.expectedFailure
2471-
@unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, inconsistent test result on Windows due to threading")
2469+
@unittest.skip("TODO: RUSTPYTHON, inconsistent test result on Windows due to threading")
24722470
@threading_helper.reap_threads
24732471
def test_unique_composite(self):
24742472
# override __eq__ to be identity only

Lib/test/test_xml_dom_minicompat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ def test_emptynodelist___add__(self):
3535
node_list = EmptyNodeList() + NodeList()
3636
self.assertEqual(node_list, NodeList())
3737

38-
# TODO: RUSTPYTHON
39-
@unittest.expectedFailure
4038
def test_emptynodelist___radd__(self):
4139
node_list = [1,2] + EmptyNodeList()
4240
self.assertEqual(node_list, [1,2])

common/src/str.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ pub mod levenshtein {
245245
if a == b {
246246
return 0;
247247
}
248-
if (b'A'..=b'Z').contains(&a) {
248+
if a.is_ascii_uppercase() {
249249
a += b'a' - b'A';
250250
}
251-
if (b'A'..=b'Z').contains(&b) {
251+
if b.is_ascii_uppercase() {
252252
b += b'a' - b'A';
253253
}
254254
if a == b {

derive-impl/src/pyclass.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,27 @@ where
548548
other
549549
),
550550
};
551-
quote_spanned! { ident.span() =>
552-
class.set_str_attr(
553-
#py_name,
554-
ctx.make_funcdef(#py_name, Self::#ident)
555-
#doc
556-
#build_func,
557-
ctx,
558-
);
551+
if py_name.starts_with("__") && py_name.ends_with("__") {
552+
let name_ident = Ident::new(&py_name, ident.span());
553+
quote_spanned! { ident.span() =>
554+
class.set_attr(
555+
ctx.names.#name_ident,
556+
ctx.make_funcdef(#py_name, Self::#ident)
557+
#doc
558+
#build_func
559+
.into(),
560+
);
561+
}
562+
} else {
563+
quote_spanned! { ident.span() =>
564+
class.set_str_attr(
565+
#py_name,
566+
ctx.make_funcdef(#py_name, Self::#ident)
567+
#doc
568+
#build_func,
569+
ctx,
570+
);
571+
}
559572
}
560573
};
561574

rust-toolchain

Lines changed: 0 additions & 1 deletion
This file was deleted.

rust-toolchain.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[toolchain]
2+
channel = "stable"

stdlib/src/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
66
let array = module
77
.get_attr("array", vm)
88
.expect("Expect array has array type.");
9+
array.init_builtin_number_slots(&vm.ctx);
910

1011
let collections_abc = vm
1112
.import("collections.abc", None, 0)

vm/src/builtins/bool.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
use super::{PyInt, PyStrRef, PyType, PyTypeRef};
22
use crate::{
3-
class::PyClassImpl, convert::ToPyObject, function::OptionalArg, identifier, types::Constructor,
3+
class::PyClassImpl,
4+
convert::{ToPyObject, ToPyResult},
5+
function::OptionalArg,
6+
identifier,
7+
protocol::PyNumberMethods,
8+
types::{AsNumber, Constructor},
49
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject,
510
VirtualMachine,
611
};
@@ -102,7 +107,7 @@ impl Constructor for PyBool {
102107
}
103108
}
104109

105-
#[pyclass(with(Constructor))]
110+
#[pyclass(with(Constructor, AsNumber))]
106111
impl PyBool {
107112
#[pymethod(magic)]
108113
fn repr(zelf: bool, vm: &VirtualMachine) -> PyStrRef {
@@ -166,6 +171,24 @@ impl PyBool {
166171
}
167172
}
168173

174+
impl AsNumber for PyBool {
175+
fn as_number() -> &'static PyNumberMethods {
176+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
177+
and: Some(|number, other, vm| {
178+
PyBool::and(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
179+
}),
180+
xor: Some(|number, other, vm| {
181+
PyBool::xor(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
182+
}),
183+
or: Some(|number, other, vm| {
184+
PyBool::or(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
185+
}),
186+
..PyInt::AS_NUMBER
187+
};
188+
&AS_NUMBER
189+
}
190+
}
191+
169192
pub(crate) fn init(context: &Context) {
170193
PyBool::extend_class(context, context.types.bool_type);
171194
}

vm/src/builtins/bytearray.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ use crate::{
3737
VirtualMachine,
3838
};
3939
use bstr::ByteSlice;
40-
use once_cell::sync::Lazy;
4140
use std::mem::size_of;
4241

4342
#[pyclass(module = false, name = "bytearray", unhashable = true)]
@@ -859,14 +858,16 @@ impl AsSequence for PyByteArray {
859858

860859
impl AsNumber for PyByteArray {
861860
fn as_number() -> &'static PyNumberMethods {
862-
static AS_NUMBER: Lazy<PyNumberMethods> = Lazy::new(|| PyNumberMethods {
863-
remainder: atomic_func!(|number, other, vm| {
864-
PyByteArray::number_downcast(number)
865-
.mod_(other.to_owned(), vm)
866-
.to_pyresult(vm)
861+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
862+
remainder: Some(|number, other, vm| {
863+
if let Some(number) = number.obj.downcast_ref::<PyByteArray>() {
864+
number.mod_(other.to_owned(), vm).to_pyresult(vm)
865+
} else {
866+
Ok(vm.ctx.not_implemented())
867+
}
867868
}),
868869
..PyNumberMethods::NOT_IMPLEMENTED
869-
});
870+
};
870871
&AS_NUMBER
871872
}
872873
}

vm/src/builtins/bytes.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -629,14 +629,16 @@ impl AsSequence for PyBytes {
629629

630630
impl AsNumber for PyBytes {
631631
fn as_number() -> &'static PyNumberMethods {
632-
static AS_NUMBER: Lazy<PyNumberMethods> = Lazy::new(|| PyNumberMethods {
633-
remainder: atomic_func!(|number, other, vm| {
634-
PyBytes::number_downcast(number)
635-
.mod_(other.to_owned(), vm)
636-
.to_pyresult(vm)
632+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
633+
remainder: Some(|number, other, vm| {
634+
if let Some(number) = number.obj.downcast_ref::<PyBytes>() {
635+
number.mod_(other.to_owned(), vm).to_pyresult(vm)
636+
} else {
637+
Ok(vm.ctx.not_implemented())
638+
}
637639
}),
638640
..PyNumberMethods::NOT_IMPLEMENTED
639-
});
641+
};
640642
&AS_NUMBER
641643
}
642644
}

vm/src/builtins/complex.rs

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use super::{float, PyStr, PyType, PyTypeRef};
22
use crate::{
3-
atomic_func,
43
class::PyClassImpl,
54
convert::{ToPyObject, ToPyResult},
65
function::{
@@ -15,7 +14,6 @@ use crate::{
1514
};
1615
use num_complex::Complex64;
1716
use num_traits::Zero;
18-
use once_cell::sync::Lazy;
1917
use rustpython_common::{float_ops, hash};
2018
use std::num::Wrapping;
2119

@@ -454,38 +452,34 @@ impl Hashable for PyComplex {
454452

455453
impl AsNumber for PyComplex {
456454
fn as_number() -> &'static PyNumberMethods {
457-
static AS_NUMBER: Lazy<PyNumberMethods> = Lazy::new(|| PyNumberMethods {
458-
add: atomic_func!(|number, other, vm| {
455+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
456+
add: Some(|number, other, vm| {
459457
PyComplex::number_op(number, other, |a, b, _vm| a + b, vm)
460458
}),
461-
subtract: atomic_func!(|number, other, vm| {
459+
subtract: Some(|number, other, vm| {
462460
PyComplex::number_op(number, other, |a, b, _vm| a - b, vm)
463461
}),
464-
multiply: atomic_func!(|number, other, vm| {
462+
multiply: Some(|number, other, vm| {
465463
PyComplex::number_op(number, other, |a, b, _vm| a * b, vm)
466464
}),
467-
power: atomic_func!(|number, other, vm| PyComplex::number_op(
468-
number, other, inner_pow, vm
469-
)),
470-
negative: atomic_func!(|number, vm| {
465+
power: Some(|number, other, vm| PyComplex::number_op(number, other, inner_pow, vm)),
466+
negative: Some(|number, vm| {
471467
let value = PyComplex::number_downcast(number).value;
472468
(-value).to_pyresult(vm)
473469
}),
474-
positive: atomic_func!(
475-
|number, vm| PyComplex::number_downcast_exact(number, vm).to_pyresult(vm)
476-
),
477-
absolute: atomic_func!(|number, vm| {
470+
positive: Some(|number, vm| {
471+
PyComplex::number_downcast_exact(number, vm).to_pyresult(vm)
472+
}),
473+
absolute: Some(|number, vm| {
478474
let value = PyComplex::number_downcast(number).value;
479475
value.norm().to_pyresult(vm)
480476
}),
481-
boolean: atomic_func!(|number, _vm| Ok(PyComplex::number_downcast(number)
482-
.value
483-
.is_zero())),
484-
true_divide: atomic_func!(|number, other, vm| {
477+
boolean: Some(|number, _vm| Ok(PyComplex::number_downcast(number).value.is_zero())),
478+
true_divide: Some(|number, other, vm| {
485479
PyComplex::number_op(number, other, inner_div, vm)
486480
}),
487481
..PyNumberMethods::NOT_IMPLEMENTED
488-
});
482+
};
489483
&AS_NUMBER
490484
}
491485

0 commit comments

Comments
 (0)