From 55f5bf13f02f72454108c376ae70959c1c818801 Mon Sep 17 00:00:00 2001
From: Zhiyan Xiao <zhiyan.xiao@linecorp.com>
Date: Sun, 5 Mar 2023 22:53:04 +0900
Subject: [PATCH] Implement Number Protocol for PyBool

---
 vm/src/builtins/bool.rs | 54 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 52 insertions(+), 2 deletions(-)

diff --git a/vm/src/builtins/bool.rs b/vm/src/builtins/bool.rs
index 06cf95de50..29b4e46fb0 100644
--- a/vm/src/builtins/bool.rs
+++ b/vm/src/builtins/bool.rs
@@ -1,11 +1,19 @@
 use super::{PyInt, PyStrRef, PyType, PyTypeRef};
 use crate::{
-    class::PyClassImpl, convert::ToPyObject, function::OptionalArg, identifier, types::Constructor,
+    atomic_func,
+    class::PyClassImpl,
+    convert::{ToPyObject, ToPyResult},
+    function::OptionalArg,
+    identifier,
+    protocol::PyNumberMethods,
+    types::{AsNumber, Constructor},
     AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject,
     VirtualMachine,
 };
+use crossbeam_utils::atomic::AtomicCell;
 use num_bigint::Sign;
 use num_traits::Zero;
+use once_cell::sync::Lazy;
 use std::fmt::{Debug, Formatter};
 
 impl ToPyObject for bool {
@@ -102,7 +110,7 @@ impl Constructor for PyBool {
     }
 }
 
-#[pyclass(with(Constructor))]
+#[pyclass(with(Constructor, AsNumber))]
 impl PyBool {
     #[pymethod(magic)]
     fn repr(zelf: bool, vm: &VirtualMachine) -> PyStrRef {
@@ -166,6 +174,48 @@ impl PyBool {
     }
 }
 
+macro_rules! int_method {
+    ($method:ident) => {
+        AtomicCell::new(PyInt::as_number().$method.load().to_owned())
+    };
+}
+
+impl AsNumber for PyBool {
+    fn as_number() -> &'static PyNumberMethods {
+        static AS_NUMBER: Lazy<PyNumberMethods> = Lazy::new(|| PyNumberMethods {
+            add: int_method!(add),
+            subtract: int_method!(subtract),
+            multiply: int_method!(multiply),
+            remainder: int_method!(remainder),
+            divmod: int_method!(divmod),
+            power: int_method!(power),
+            negative: int_method!(negative),
+            positive: int_method!(positive),
+            absolute: int_method!(absolute),
+            boolean: int_method!(boolean),
+            invert: int_method!(invert),
+            lshift: int_method!(lshift),
+            rshift: int_method!(rshift),
+            and: atomic_func!(|number, other, vm| {
+                PyBool::and(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
+            }),
+            xor: atomic_func!(|number, other, vm| {
+                PyBool::xor(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
+            }),
+            or: atomic_func!(|number, other, vm| {
+                PyBool::or(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
+            }),
+            int: int_method!(int),
+            float: int_method!(float),
+            floor_divide: int_method!(floor_divide),
+            true_divide: int_method!(true_divide),
+            index: int_method!(index),
+            ..PyNumberMethods::NOT_IMPLEMENTED
+        });
+        &AS_NUMBER
+    }
+}
+
 pub(crate) fn init(context: &Context) {
     PyBool::extend_class(context, context.types.bool_type);
 }