diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py new file mode 100644 index 0000000000..19ebef68ae --- /dev/null +++ b/tests/snippets/bytes.py @@ -0,0 +1,6 @@ +assert b'foobar'.__eq__(2) == NotImplemented +assert b'foobar'.__ne__(2) == NotImplemented +assert b'foobar'.__gt__(2) == NotImplemented +assert b'foobar'.__ge__(2) == NotImplemented +assert b'foobar'.__lt__(2) == NotImplemented +assert b'foobar'.__le__(2) == NotImplemented diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 8d64cc5076..9390e70704 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -1,17 +1,16 @@ use std::cell::Cell; +use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::ops::Deref; use num_traits::ToPrimitive; -use crate::function::{OptionalArg, PyFuncArgs}; -use crate::pyobject::{ - PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, -}; +use crate::function::OptionalArg; +use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objint; -use super::objtype::{self, PyClassRef}; +use super::objtype::PyClassRef; #[derive(Debug)] pub struct PyBytes { @@ -57,16 +56,16 @@ pub fn init(context: &PyContext) { - an integer"; extend_class!(context, bytes_type, { - "__eq__" => context.new_rustfunc(bytes_eq), - "__lt__" => context.new_rustfunc(bytes_lt), - "__le__" => context.new_rustfunc(bytes_le), - "__gt__" => context.new_rustfunc(bytes_gt), - "__ge__" => context.new_rustfunc(bytes_ge), - "__hash__" => context.new_rustfunc(bytes_hash), "__new__" => context.new_rustfunc(bytes_new), - "__repr__" => context.new_rustfunc(bytes_repr), - "__len__" => context.new_rustfunc(bytes_len), - "__iter__" => context.new_rustfunc(bytes_iter), + "__eq__" => context.new_rustfunc(PyBytesRef::eq), + "__lt__" => context.new_rustfunc(PyBytesRef::lt), + "__le__" => context.new_rustfunc(PyBytesRef::le), + "__gt__" => context.new_rustfunc(PyBytesRef::gt), + "__ge__" => context.new_rustfunc(PyBytesRef::ge), + "__hash__" => context.new_rustfunc(PyBytesRef::hash), + "__repr__" => context.new_rustfunc(PyBytesRef::repr), + "__len__" => context.new_rustfunc(PyBytesRef::len), + "__iter__" => context.new_rustfunc(PyBytesRef::iter), "__doc__" => context.new_str(bytes_doc.to_string()) }); } @@ -93,111 +92,71 @@ fn bytes_new( PyBytes::new(value).into_ref_with_type(vm, cls) } -fn bytes_eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() == get_value(b).to_vec() - } else { - false - }; - Ok(vm.ctx.new_bool(result)) -} - -fn bytes_ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() >= get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>'", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} - -fn bytes_gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); +impl PyBytesRef { + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value == other.value) + } else { + vm.ctx.not_implemented() + } + } - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() > get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>='", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value >= other.value) + } else { + vm.ctx.not_implemented() + } + } -fn bytes_le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); + fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value > other.value) + } else { + vm.ctx.not_implemented() + } + } - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() <= get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<'", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value <= other.value) + } else { + vm.ctx.not_implemented() + } + } -fn bytes_lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); + fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value < other.value) + } else { + vm.ctx.not_implemented() + } + } - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() < get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<='", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + fn len(self, _vm: &VirtualMachine) -> usize { + self.value.len() + } -fn bytes_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]); + fn hash(self, _vm: &VirtualMachine) -> u64 { + let mut hasher = DefaultHasher::new(); + self.value.hash(&mut hasher); + hasher.finish() + } - let byte_vec = get_value(a).to_vec(); - Ok(vm.ctx.new_int(byte_vec.len())) -} + fn repr(self, _vm: &VirtualMachine) -> String { + // TODO: don't just unwrap + let data = String::from_utf8(self.value.clone()).unwrap(); + format!("b'{}'", data) + } -fn bytes_hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytes_type()))]); - let data = get_value(zelf); - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - data.hash(&mut hasher); - let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash)) + fn iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue { + PyIteratorValue { + position: Cell::new(0), + iterated_obj: obj.into_object(), + } + } } pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { &obj.payload::().unwrap().value } - -fn bytes_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]); - let value = get_value(obj); - let data = String::from_utf8(value.to_vec()).unwrap(); - Ok(vm.new_str(format!("b'{}'", data))) -} - -fn bytes_iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue { - PyIteratorValue { - position: Cell::new(0), - iterated_obj: obj.into_object(), - } -}