Skip to content

Commit cb0b6a8

Browse files
Merge pull request RustPython#759 from RustPython/joey/convert-bytes
Convert bytes to new args style
2 parents c5c9181 + c2d04f9 commit cb0b6a8

File tree

2 files changed

+74
-109
lines changed

2 files changed

+74
-109
lines changed

tests/snippets/bytes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
assert b'foobar'.__eq__(2) == NotImplemented
2+
assert b'foobar'.__ne__(2) == NotImplemented
3+
assert b'foobar'.__gt__(2) == NotImplemented
4+
assert b'foobar'.__ge__(2) == NotImplemented
5+
assert b'foobar'.__lt__(2) == NotImplemented
6+
assert b'foobar'.__le__(2) == NotImplemented

vm/src/obj/objbytes.rs

Lines changed: 68 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
use std::cell::Cell;
2+
use std::collections::hash_map::DefaultHasher;
23
use std::hash::{Hash, Hasher};
34
use std::ops::Deref;
45

56
use num_traits::ToPrimitive;
67

7-
use crate::function::{OptionalArg, PyFuncArgs};
8-
use crate::pyobject::{
9-
PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
10-
};
8+
use crate::function::OptionalArg;
9+
use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue};
1110
use crate::vm::VirtualMachine;
1211

1312
use super::objint;
14-
use super::objtype::{self, PyClassRef};
13+
use super::objtype::PyClassRef;
1514

1615
#[derive(Debug)]
1716
pub struct PyBytes {
@@ -57,16 +56,16 @@ pub fn init(context: &PyContext) {
5756
- an integer";
5857

5958
extend_class!(context, bytes_type, {
60-
"__eq__" => context.new_rustfunc(bytes_eq),
61-
"__lt__" => context.new_rustfunc(bytes_lt),
62-
"__le__" => context.new_rustfunc(bytes_le),
63-
"__gt__" => context.new_rustfunc(bytes_gt),
64-
"__ge__" => context.new_rustfunc(bytes_ge),
65-
"__hash__" => context.new_rustfunc(bytes_hash),
6659
"__new__" => context.new_rustfunc(bytes_new),
67-
"__repr__" => context.new_rustfunc(bytes_repr),
68-
"__len__" => context.new_rustfunc(bytes_len),
69-
"__iter__" => context.new_rustfunc(bytes_iter),
60+
"__eq__" => context.new_rustfunc(PyBytesRef::eq),
61+
"__lt__" => context.new_rustfunc(PyBytesRef::lt),
62+
"__le__" => context.new_rustfunc(PyBytesRef::le),
63+
"__gt__" => context.new_rustfunc(PyBytesRef::gt),
64+
"__ge__" => context.new_rustfunc(PyBytesRef::ge),
65+
"__hash__" => context.new_rustfunc(PyBytesRef::hash),
66+
"__repr__" => context.new_rustfunc(PyBytesRef::repr),
67+
"__len__" => context.new_rustfunc(PyBytesRef::len),
68+
"__iter__" => context.new_rustfunc(PyBytesRef::iter),
7069
"__doc__" => context.new_str(bytes_doc.to_string())
7170
});
7271
}
@@ -93,111 +92,71 @@ fn bytes_new(
9392
PyBytes::new(value).into_ref_with_type(vm, cls)
9493
}
9594

96-
fn bytes_eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
97-
arg_check!(
98-
vm,
99-
args,
100-
required = [(a, Some(vm.ctx.bytes_type())), (b, None)]
101-
);
102-
103-
let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) {
104-
get_value(a).to_vec() == get_value(b).to_vec()
105-
} else {
106-
false
107-
};
108-
Ok(vm.ctx.new_bool(result))
109-
}
110-
111-
fn bytes_ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
112-
arg_check!(
113-
vm,
114-
args,
115-
required = [(a, Some(vm.ctx.bytes_type())), (b, None)]
116-
);
117-
118-
let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) {
119-
get_value(a).to_vec() >= get_value(b).to_vec()
120-
} else {
121-
return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>'", a, b)));
122-
};
123-
Ok(vm.ctx.new_bool(result))
124-
}
125-
126-
fn bytes_gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
127-
arg_check!(
128-
vm,
129-
args,
130-
required = [(a, Some(vm.ctx.bytes_type())), (b, None)]
131-
);
95+
impl PyBytesRef {
96+
fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
97+
if let Ok(other) = other.downcast::<PyBytes>() {
98+
vm.ctx.new_bool(self.value == other.value)
99+
} else {
100+
vm.ctx.not_implemented()
101+
}
102+
}
132103

133-
let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) {
134-
get_value(a).to_vec() > get_value(b).to_vec()
135-
} else {
136-
return Err(vm.new_type_error(format!("Cannot compare {} and {} using '>='", a, b)));
137-
};
138-
Ok(vm.ctx.new_bool(result))
139-
}
104+
fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
105+
if let Ok(other) = other.downcast::<PyBytes>() {
106+
vm.ctx.new_bool(self.value >= other.value)
107+
} else {
108+
vm.ctx.not_implemented()
109+
}
110+
}
140111

141-
fn bytes_le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
142-
arg_check!(
143-
vm,
144-
args,
145-
required = [(a, Some(vm.ctx.bytes_type())), (b, None)]
146-
);
112+
fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
113+
if let Ok(other) = other.downcast::<PyBytes>() {
114+
vm.ctx.new_bool(self.value > other.value)
115+
} else {
116+
vm.ctx.not_implemented()
117+
}
118+
}
147119

148-
let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) {
149-
get_value(a).to_vec() <= get_value(b).to_vec()
150-
} else {
151-
return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<'", a, b)));
152-
};
153-
Ok(vm.ctx.new_bool(result))
154-
}
120+
fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
121+
if let Ok(other) = other.downcast::<PyBytes>() {
122+
vm.ctx.new_bool(self.value <= other.value)
123+
} else {
124+
vm.ctx.not_implemented()
125+
}
126+
}
155127

156-
fn bytes_lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
157-
arg_check!(
158-
vm,
159-
args,
160-
required = [(a, Some(vm.ctx.bytes_type())), (b, None)]
161-
);
128+
fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
129+
if let Ok(other) = other.downcast::<PyBytes>() {
130+
vm.ctx.new_bool(self.value < other.value)
131+
} else {
132+
vm.ctx.not_implemented()
133+
}
134+
}
162135

163-
let result = if objtype::isinstance(b, &vm.ctx.bytes_type()) {
164-
get_value(a).to_vec() < get_value(b).to_vec()
165-
} else {
166-
return Err(vm.new_type_error(format!("Cannot compare {} and {} using '<='", a, b)));
167-
};
168-
Ok(vm.ctx.new_bool(result))
169-
}
136+
fn len(self, _vm: &VirtualMachine) -> usize {
137+
self.value.len()
138+
}
170139

171-
fn bytes_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
172-
arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]);
140+
fn hash(self, _vm: &VirtualMachine) -> u64 {
141+
let mut hasher = DefaultHasher::new();
142+
self.value.hash(&mut hasher);
143+
hasher.finish()
144+
}
173145

174-
let byte_vec = get_value(a).to_vec();
175-
Ok(vm.ctx.new_int(byte_vec.len()))
176-
}
146+
fn repr(self, _vm: &VirtualMachine) -> String {
147+
// TODO: don't just unwrap
148+
let data = String::from_utf8(self.value.clone()).unwrap();
149+
format!("b'{}'", data)
150+
}
177151

178-
fn bytes_hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
179-
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytes_type()))]);
180-
let data = get_value(zelf);
181-
let mut hasher = std::collections::hash_map::DefaultHasher::new();
182-
data.hash(&mut hasher);
183-
let hash = hasher.finish();
184-
Ok(vm.ctx.new_int(hash))
152+
fn iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue {
153+
PyIteratorValue {
154+
position: Cell::new(0),
155+
iterated_obj: obj.into_object(),
156+
}
157+
}
185158
}
186159

187160
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
188161
&obj.payload::<PyBytes>().unwrap().value
189162
}
190-
191-
fn bytes_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
192-
arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]);
193-
let value = get_value(obj);
194-
let data = String::from_utf8(value.to_vec()).unwrap();
195-
Ok(vm.new_str(format!("b'{}'", data)))
196-
}
197-
198-
fn bytes_iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue {
199-
PyIteratorValue {
200-
position: Cell::new(0),
201-
iterated_obj: obj.into_object(),
202-
}
203-
}

0 commit comments

Comments
 (0)