From b0d7960cc59ed1a9a1a8f794818349024eac3f4f Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Wed, 27 Mar 2019 19:14:37 -0700 Subject: [PATCH 1/4] bytes: move methods to impl block --- vm/src/obj/objbytes.rs | 208 +++++++++++++++++++++-------------------- 1 file changed, 105 insertions(+), 103 deletions(-) diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 8d64cc5076..5b725be501 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -57,16 +57,16 @@ pub fn init(context: &PyContext) { - an integer"; extend_class!(context, bytes_type, { - "__eq__" => context.new_rustfunc(bytes_eq), - "__lt__" => context.new_rustfunc(bytes_lt), - "__le__" => context.new_rustfunc(bytes_le), - "__gt__" => context.new_rustfunc(bytes_gt), - "__ge__" => context.new_rustfunc(bytes_ge), - "__hash__" => context.new_rustfunc(bytes_hash), "__new__" => context.new_rustfunc(bytes_new), - "__repr__" => context.new_rustfunc(bytes_repr), - "__len__" => context.new_rustfunc(bytes_len), - "__iter__" => context.new_rustfunc(bytes_iter), + "__eq__" => context.new_rustfunc(PyBytesRef::eq), + "__lt__" => context.new_rustfunc(PyBytesRef::lt), + "__le__" => context.new_rustfunc(PyBytesRef::le), + "__gt__" => context.new_rustfunc(PyBytesRef::gt), + "__ge__" => context.new_rustfunc(PyBytesRef::ge), + "__hash__" => context.new_rustfunc(PyBytesRef::hash), + "__repr__" => context.new_rustfunc(PyBytesRef::repr), + "__len__" => context.new_rustfunc(PyBytesRef::len), + "__iter__" => context.new_rustfunc(PyBytesRef::iter), "__doc__" => context.new_str(bytes_doc.to_string()) }); } @@ -93,111 +93,113 @@ fn bytes_new( PyBytes::new(value).into_ref_with_type(vm, cls) } -fn bytes_eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() == get_value(b).to_vec() - } else { - false - }; - Ok(vm.ctx.new_bool(result)) -} - -fn bytes_ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() >= get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>'", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} +impl PyBytesRef { + fn eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(a, Some(vm.ctx.bytes_type())), (b, None)] + ); + + let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { + get_value(a).to_vec() == get_value(b).to_vec() + } else { + false + }; + Ok(vm.ctx.new_bool(result)) + } -fn bytes_gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); + fn ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(a, Some(vm.ctx.bytes_type())), (b, None)] + ); + + let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { + get_value(a).to_vec() >= get_value(b).to_vec() + } else { + return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>'", a, b))); + }; + Ok(vm.ctx.new_bool(result)) + } - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() > get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>='", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + fn gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(a, Some(vm.ctx.bytes_type())), (b, None)] + ); + + let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { + get_value(a).to_vec() > get_value(b).to_vec() + } else { + return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>='", a, b))); + }; + Ok(vm.ctx.new_bool(result)) + } -fn bytes_le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); + fn le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(a, Some(vm.ctx.bytes_type())), (b, None)] + ); + + let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { + get_value(a).to_vec() <= get_value(b).to_vec() + } else { + return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<'", a, b))); + }; + Ok(vm.ctx.new_bool(result)) + } - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() <= get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<'", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + fn lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(a, Some(vm.ctx.bytes_type())), (b, None)] + ); + + let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { + get_value(a).to_vec() < get_value(b).to_vec() + } else { + return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<='", a, b))); + }; + Ok(vm.ctx.new_bool(result)) + } -fn bytes_lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); + fn len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]); - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() < get_value(b).to_vec() - } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<='", a, b))); - }; - Ok(vm.ctx.new_bool(result)) -} + let byte_vec = get_value(a).to_vec(); + Ok(vm.ctx.new_int(byte_vec.len())) + } -fn bytes_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]); + fn hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytes_type()))]); + let data = get_value(zelf); + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + data.hash(&mut hasher); + let hash = hasher.finish(); + Ok(vm.ctx.new_int(hash)) + } - let byte_vec = get_value(a).to_vec(); - Ok(vm.ctx.new_int(byte_vec.len())) -} + fn repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]); + let value = get_value(obj); + let data = String::from_utf8(value.to_vec()).unwrap(); + Ok(vm.new_str(format!("b'{}'", data))) + } -fn bytes_hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytes_type()))]); - let data = get_value(zelf); - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - data.hash(&mut hasher); - let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash)) + fn iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue { + PyIteratorValue { + position: Cell::new(0), + iterated_obj: obj.into_object(), + } + } } pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { &obj.payload::().unwrap().value } - -fn bytes_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]); - let value = get_value(obj); - let data = String::from_utf8(value.to_vec()).unwrap(); - Ok(vm.new_str(format!("b'{}'", data))) -} - -fn bytes_iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue { - PyIteratorValue { - position: Cell::new(0), - iterated_obj: obj.into_object(), - } -} From e4272126cf5589ecd24a6a89a77daba364968c49 Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Wed, 27 Mar 2019 19:15:38 -0700 Subject: [PATCH 2/4] bytes: return NotImplemented where appropriate --- vm/src/obj/objbytes.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 5b725be501..bed57157e9 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -104,7 +104,7 @@ impl PyBytesRef { let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { get_value(a).to_vec() == get_value(b).to_vec() } else { - false + return Ok(vm.ctx.not_implemented()); }; Ok(vm.ctx.new_bool(result)) } @@ -119,7 +119,7 @@ impl PyBytesRef { let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { get_value(a).to_vec() >= get_value(b).to_vec() } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>'", a, b))); + return Ok(vm.ctx.not_implemented()); }; Ok(vm.ctx.new_bool(result)) } @@ -134,7 +134,7 @@ impl PyBytesRef { let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { get_value(a).to_vec() > get_value(b).to_vec() } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>='", a, b))); + return Ok(vm.ctx.not_implemented()); }; Ok(vm.ctx.new_bool(result)) } @@ -149,7 +149,7 @@ impl PyBytesRef { let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { get_value(a).to_vec() <= get_value(b).to_vec() } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<'", a, b))); + return Ok(vm.ctx.not_implemented()); }; Ok(vm.ctx.new_bool(result)) } @@ -164,7 +164,7 @@ impl PyBytesRef { let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { get_value(a).to_vec() < get_value(b).to_vec() } else { - return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<='", a, b))); + return Ok(vm.ctx.not_implemented()); }; Ok(vm.ctx.new_bool(result)) } From e0aca86473738ea8cf592fce5b7e4769293b6227 Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Wed, 27 Mar 2019 19:17:27 -0700 Subject: [PATCH 3/4] bytes: convert methods to new args style --- vm/src/obj/objbytes.rs | 121 +++++++++++++---------------------------- 1 file changed, 39 insertions(+), 82 deletions(-) diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index bed57157e9..9390e70704 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -1,17 +1,16 @@ use std::cell::Cell; +use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::ops::Deref; use num_traits::ToPrimitive; -use crate::function::{OptionalArg, PyFuncArgs}; -use crate::pyobject::{ - PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, -}; +use crate::function::OptionalArg; +use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objint; -use super::objtype::{self, PyClassRef}; +use super::objtype::PyClassRef; #[derive(Debug)] pub struct PyBytes { @@ -94,102 +93,60 @@ fn bytes_new( } impl PyBytesRef { - fn eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() == get_value(b).to_vec() + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value == other.value) } else { - return Ok(vm.ctx.not_implemented()); - }; - Ok(vm.ctx.new_bool(result)) + vm.ctx.not_implemented() + } } - fn ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() >= get_value(b).to_vec() + fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value >= other.value) } else { - return Ok(vm.ctx.not_implemented()); - }; - Ok(vm.ctx.new_bool(result)) + vm.ctx.not_implemented() + } } - fn gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() > get_value(b).to_vec() + fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value > other.value) } else { - return Ok(vm.ctx.not_implemented()); - }; - Ok(vm.ctx.new_bool(result)) + vm.ctx.not_implemented() + } } - fn le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() <= get_value(b).to_vec() + fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value <= other.value) } else { - return Ok(vm.ctx.not_implemented()); - }; - Ok(vm.ctx.new_bool(result)) + vm.ctx.not_implemented() + } } - fn lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(a, Some(vm.ctx.bytes_type())), (b, None)] - ); - - let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) { - get_value(a).to_vec() < get_value(b).to_vec() + fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if let Ok(other) = other.downcast::() { + vm.ctx.new_bool(self.value < other.value) } else { - return Ok(vm.ctx.not_implemented()); - }; - Ok(vm.ctx.new_bool(result)) + vm.ctx.not_implemented() + } } - fn len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]); - - let byte_vec = get_value(a).to_vec(); - Ok(vm.ctx.new_int(byte_vec.len())) + fn len(self, _vm: &VirtualMachine) -> usize { + self.value.len() } - fn hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytes_type()))]); - let data = get_value(zelf); - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - data.hash(&mut hasher); - let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash)) + fn hash(self, _vm: &VirtualMachine) -> u64 { + let mut hasher = DefaultHasher::new(); + self.value.hash(&mut hasher); + hasher.finish() } - fn repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]); - let value = get_value(obj); - let data = String::from_utf8(value.to_vec()).unwrap(); - Ok(vm.new_str(format!("b'{}'", data))) + fn repr(self, _vm: &VirtualMachine) -> String { + // TODO: don't just unwrap + let data = String::from_utf8(self.value.clone()).unwrap(); + format!("b'{}'", data) } fn iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue { From c2d04f97d80a0eb2916ac7c4ce630b43a528f26b Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Sat, 30 Mar 2019 13:33:50 -0700 Subject: [PATCH 4/4] bytes: add tests for NotImplemented --- tests/snippets/bytes.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 tests/snippets/bytes.py diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py new file mode 100644 index 0000000000..19ebef68ae --- /dev/null +++ b/tests/snippets/bytes.py @@ -0,0 +1,6 @@ +assert b'foobar'.__eq__(2) == NotImplemented +assert b'foobar'.__ne__(2) == NotImplemented +assert b'foobar'.__gt__(2) == NotImplemented +assert b'foobar'.__ge__(2) == NotImplemented +assert b'foobar'.__lt__(2) == NotImplemented +assert b'foobar'.__le__(2) == NotImplemented