Skip to content

Commit ea665cb

Browse files
authored
Merge pull request #4615 from xiaozhiyan/binaryops-with-number-protocol
Improve: binary ops with Number Protocol
2 parents 1fceeab + 2287720 commit ea665cb

29 files changed

+1255
-737
lines changed

Lib/test/test_collections.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ def __contains__(self, key):
259259
d = c.new_child(b=20, c=30)
260260
self.assertEqual(d.maps, [{'b': 20, 'c': 30}, {'a': 1, 'b': 2}])
261261

262+
# TODO: RUSTPYTHON
263+
@unittest.expectedFailure
262264
def test_union_operators(self):
263265
cm1 = ChainMap(dict(a=1, b=2), dict(c=3, d=4))
264266
cm2 = ChainMap(dict(a=10, e=5), dict(b=20, d=4))

Lib/test/test_enum.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,9 +2466,7 @@ class Color(StrMixin, AllMixin, Flag):
24662466
self.assertEqual(Color.ALL.value, 7)
24672467
self.assertEqual(str(Color.BLUE), 'blue')
24682468

2469-
# TODO: RUSTPYTHON
2470-
@unittest.expectedFailure
2471-
@unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, inconsistent test result on Windows due to threading")
2469+
@unittest.skip("TODO: RUSTPYTHON, inconsistent test result on Windows due to threading")
24722470
@threading_helper.reap_threads
24732471
def test_unique_composite(self):
24742472
# override __eq__ to be identity only

Lib/test/test_xml_dom_minicompat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ def test_emptynodelist___add__(self):
3535
node_list = EmptyNodeList() + NodeList()
3636
self.assertEqual(node_list, NodeList())
3737

38-
# TODO: RUSTPYTHON
39-
@unittest.expectedFailure
4038
def test_emptynodelist___radd__(self):
4139
node_list = [1,2] + EmptyNodeList()
4240
self.assertEqual(node_list, [1,2])

common/src/str.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ pub mod levenshtein {
245245
if a == b {
246246
return 0;
247247
}
248-
if (b'A'..=b'Z').contains(&a) {
248+
if a.is_ascii_uppercase() {
249249
a += b'a' - b'A';
250250
}
251-
if (b'A'..=b'Z').contains(&b) {
251+
if b.is_ascii_uppercase() {
252252
b += b'a' - b'A';
253253
}
254254
if a == b {

derive-impl/src/pyclass.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,27 @@ where
548548
other
549549
),
550550
};
551-
quote_spanned! { ident.span() =>
552-
class.set_str_attr(
553-
#py_name,
554-
ctx.make_funcdef(#py_name, Self::#ident)
555-
#doc
556-
#build_func,
557-
ctx,
558-
);
551+
if py_name.starts_with("__") && py_name.ends_with("__") {
552+
let name_ident = Ident::new(&py_name, ident.span());
553+
quote_spanned! { ident.span() =>
554+
class.set_attr(
555+
ctx.names.#name_ident,
556+
ctx.make_funcdef(#py_name, Self::#ident)
557+
#doc
558+
#build_func
559+
.into(),
560+
);
561+
}
562+
} else {
563+
quote_spanned! { ident.span() =>
564+
class.set_str_attr(
565+
#py_name,
566+
ctx.make_funcdef(#py_name, Self::#ident)
567+
#doc
568+
#build_func,
569+
ctx,
570+
);
571+
}
559572
}
560573
};
561574

rust-toolchain

Lines changed: 0 additions & 1 deletion
This file was deleted.

rust-toolchain.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[toolchain]
2+
channel = "stable"

stdlib/src/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
66
let array = module
77
.get_attr("array", vm)
88
.expect("Expect array has array type.");
9+
array.init_builtin_number_slots(&vm.ctx);
910

1011
let collections_abc = vm
1112
.import("collections.abc", None, 0)

vm/src/builtins/bool.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
use super::{PyInt, PyStrRef, PyType, PyTypeRef};
22
use crate::{
3-
class::PyClassImpl, convert::ToPyObject, function::OptionalArg, identifier, types::Constructor,
3+
class::PyClassImpl,
4+
convert::{ToPyObject, ToPyResult},
5+
function::OptionalArg,
6+
identifier,
7+
protocol::PyNumberMethods,
8+
types::{AsNumber, Constructor},
49
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject,
510
VirtualMachine,
611
};
@@ -102,7 +107,7 @@ impl Constructor for PyBool {
102107
}
103108
}
104109

105-
#[pyclass(with(Constructor))]
110+
#[pyclass(with(Constructor, AsNumber))]
106111
impl PyBool {
107112
#[pymethod(magic)]
108113
fn repr(zelf: bool, vm: &VirtualMachine) -> PyStrRef {
@@ -166,6 +171,24 @@ impl PyBool {
166171
}
167172
}
168173

174+
impl AsNumber for PyBool {
175+
fn as_number() -> &'static PyNumberMethods {
176+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
177+
and: Some(|number, other, vm| {
178+
PyBool::and(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
179+
}),
180+
xor: Some(|number, other, vm| {
181+
PyBool::xor(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
182+
}),
183+
or: Some(|number, other, vm| {
184+
PyBool::or(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
185+
}),
186+
..PyInt::AS_NUMBER
187+
};
188+
&AS_NUMBER
189+
}
190+
}
191+
169192
pub(crate) fn init(context: &Context) {
170193
PyBool::extend_class(context, context.types.bool_type);
171194
}

vm/src/builtins/bytearray.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ use crate::{
3737
VirtualMachine,
3838
};
3939
use bstr::ByteSlice;
40-
use once_cell::sync::Lazy;
4140
use std::mem::size_of;
4241

4342
#[pyclass(module = false, name = "bytearray", unhashable = true)]
@@ -859,14 +858,16 @@ impl AsSequence for PyByteArray {
859858

860859
impl AsNumber for PyByteArray {
861860
fn as_number() -> &'static PyNumberMethods {
862-
static AS_NUMBER: Lazy<PyNumberMethods> = Lazy::new(|| PyNumberMethods {
863-
remainder: atomic_func!(|number, other, vm| {
864-
PyByteArray::number_downcast(number)
865-
.mod_(other.to_owned(), vm)
866-
.to_pyresult(vm)
861+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
862+
remainder: Some(|number, other, vm| {
863+
if let Some(number) = number.obj.downcast_ref::<PyByteArray>() {
864+
number.mod_(other.to_owned(), vm).to_pyresult(vm)
865+
} else {
866+
Ok(vm.ctx.not_implemented())
867+
}
867868
}),
868869
..PyNumberMethods::NOT_IMPLEMENTED
869-
});
870+
};
870871
&AS_NUMBER
871872
}
872873
}

0 commit comments

Comments
 (0)