Skip to content

Commit 4ce0a7c

Browse files
authored
Merge pull request RustPython#586 from RustPython/joey/float-range-any-payload
Convert more objects to `Any` payload
2 parents 5d28f9b + 2d71f6d commit 4ce0a7c

File tree

8 files changed

+185
-115
lines changed

8 files changed

+185
-115
lines changed

vm/src/obj/objbool.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::objfloat::PyFloat;
12
use super::objstr::PyString;
23
use super::objtype;
34
use crate::pyobject::{
@@ -16,9 +17,11 @@ pub fn boolval(vm: &mut VirtualMachine, obj: PyObjectRef) -> Result<bool, PyObje
1617
if let Some(s) = obj.payload::<PyString>() {
1718
return Ok(!s.value.is_empty());
1819
}
20+
if let Some(value) = obj.payload::<PyFloat>() {
21+
return Ok(*value != PyFloat::from(0.0));
22+
}
1923
let result = match obj.payload {
2024
PyObjectPayload::Integer { ref value } => !value.is_zero(),
21-
PyObjectPayload::Float { value } => value != 0.0,
2225
PyObjectPayload::Sequence { ref elements } => !elements.borrow().is_empty(),
2326
PyObjectPayload::Dict { ref elements } => !elements.borrow().is_empty(),
2427
PyObjectPayload::None { .. } => false,

vm/src/obj/objbytearray.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,47 @@
11
//! Implementation of the python bytearray object.
22
33
use std::cell::RefCell;
4+
use std::ops::{Deref, DerefMut};
45

5-
use crate::pyobject::{PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyResult, TypeProtocol};
6+
use crate::pyobject::{
7+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult,
8+
TypeProtocol,
9+
};
610

711
use super::objint;
812

9-
use super::objbytes::get_mut_value;
10-
use super::objbytes::get_value;
1113
use super::objtype;
1214
use crate::vm::VirtualMachine;
1315
use num_traits::ToPrimitive;
1416

17+
#[derive(Debug)]
18+
pub struct PyByteArray {
19+
// TODO: shouldn't be public
20+
pub value: RefCell<Vec<u8>>,
21+
}
22+
23+
impl PyByteArray {
24+
pub fn new(data: Vec<u8>) -> Self {
25+
PyByteArray {
26+
value: RefCell::new(data),
27+
}
28+
}
29+
}
30+
31+
impl PyObjectPayload2 for PyByteArray {
32+
fn required_type(ctx: &PyContext) -> PyObjectRef {
33+
ctx.bytearray_type()
34+
}
35+
}
36+
37+
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
38+
obj.payload::<PyByteArray>().unwrap().value.borrow()
39+
}
40+
41+
pub fn get_mut_value<'a>(obj: &'a PyObjectRef) -> impl DerefMut<Target = Vec<u8>> + 'a {
42+
obj.payload::<PyByteArray>().unwrap().value.borrow_mut()
43+
}
44+
1545
// Binary data support
1646

1747
/// Fill bytearray class methods dictionary.
@@ -143,8 +173,8 @@ fn bytearray_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
143173
vec![]
144174
};
145175
Ok(PyObject::new(
146-
PyObjectPayload::Bytes {
147-
value: RefCell::new(value),
176+
PyObjectPayload::AnyRustValue {
177+
value: Box::new(PyByteArray::new(value)),
148178
},
149179
cls.clone(),
150180
))
@@ -290,13 +320,8 @@ fn bytearray_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
290320

291321
fn bytearray_clear(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
292322
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytearray_type()))]);
293-
match zelf.payload {
294-
PyObjectPayload::Bytes { ref value } => {
295-
value.borrow_mut().clear();
296-
Ok(vm.get_none())
297-
}
298-
_ => panic!("Bytearray has incorrect payload."),
299-
}
323+
get_mut_value(zelf).clear();
324+
Ok(vm.get_none())
300325
}
301326

302327
fn bytearray_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

vm/src/obj/objbytes.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,41 @@
1-
use std::cell::{Cell, RefCell};
1+
use std::cell::Cell;
22
use std::hash::{Hash, Hasher};
33
use std::ops::Deref;
4-
use std::ops::DerefMut;
54

65
use super::objint;
76
use super::objtype;
87
use crate::pyobject::{
9-
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
8+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult,
9+
TypeProtocol,
1010
};
1111
use crate::vm::VirtualMachine;
1212
use num_traits::ToPrimitive;
1313

14+
#[derive(Debug)]
15+
pub struct PyBytes {
16+
value: Vec<u8>,
17+
}
18+
19+
impl PyBytes {
20+
pub fn new(data: Vec<u8>) -> Self {
21+
PyBytes { value: data }
22+
}
23+
}
24+
25+
impl Deref for PyBytes {
26+
type Target = [u8];
27+
28+
fn deref(&self) -> &[u8] {
29+
&self.value
30+
}
31+
}
32+
33+
impl PyObjectPayload2 for PyBytes {
34+
fn required_type(ctx: &PyContext) -> PyObjectRef {
35+
ctx.bytes_type()
36+
}
37+
}
38+
1439
// Binary data support
1540

1641
// Fill bytes class methods:
@@ -71,8 +96,8 @@ fn bytes_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7196
};
7297

7398
Ok(PyObject::new(
74-
PyObjectPayload::Bytes {
75-
value: RefCell::new(value),
99+
PyObjectPayload::AnyRustValue {
100+
value: Box::new(PyBytes::new(value)),
76101
},
77102
cls.clone(),
78103
))
@@ -170,19 +195,7 @@ fn bytes_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
170195
}
171196

172197
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
173-
if let PyObjectPayload::Bytes { ref value } = obj.payload {
174-
value.borrow()
175-
} else {
176-
panic!("Inner error getting bytearray {:?}", obj);
177-
}
178-
}
179-
180-
pub fn get_mut_value<'a>(obj: &'a PyObjectRef) -> impl DerefMut<Target = Vec<u8>> + 'a {
181-
if let PyObjectPayload::Bytes { ref value } = obj.payload {
182-
value.borrow_mut()
183-
} else {
184-
panic!("Inner error getting bytearray {:?}", obj);
185-
}
198+
&obj.payload::<PyBytes>().unwrap().value
186199
}
187200

188201
fn bytes_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

vm/src/obj/objfloat.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,30 @@ use super::objint;
33
use super::objstr;
44
use super::objtype;
55
use crate::pyobject::{
6-
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
6+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult,
7+
TypeProtocol,
78
};
89
use crate::vm::VirtualMachine;
910
use num_bigint::ToBigInt;
1011
use num_traits::ToPrimitive;
1112

13+
#[derive(Debug, Copy, Clone, PartialEq)]
14+
pub struct PyFloat {
15+
value: f64,
16+
}
17+
18+
impl PyObjectPayload2 for PyFloat {
19+
fn required_type(ctx: &PyContext) -> PyObjectRef {
20+
ctx.float_type()
21+
}
22+
}
23+
24+
impl From<f64> for PyFloat {
25+
fn from(value: f64) -> Self {
26+
PyFloat { value }
27+
}
28+
}
29+
1230
fn float_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
1331
arg_check!(vm, args, required = [(float, Some(vm.ctx.float_type()))]);
1432
let v = get_value(float);
@@ -50,16 +68,18 @@ fn float_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5068
let type_name = objtype::get_type_name(&arg.typ());
5169
return Err(vm.new_type_error(format!("can't convert {} to float", type_name)));
5270
};
53-
Ok(PyObject::new(PyObjectPayload::Float { value }, cls.clone()))
71+
72+
Ok(PyObject::new(
73+
PyObjectPayload::AnyRustValue {
74+
value: Box::new(PyFloat { value }),
75+
},
76+
cls.clone(),
77+
))
5478
}
5579

5680
// Retrieve inner float value:
5781
pub fn get_value(obj: &PyObjectRef) -> f64 {
58-
if let PyObjectPayload::Float { value } = &obj.payload {
59-
*value
60-
} else {
61-
panic!("Inner error getting float: {}", obj);
62-
}
82+
obj.payload::<PyFloat>().unwrap().value
6383
}
6484

6585
pub fn make_float(vm: &mut VirtualMachine, obj: &PyObjectRef) -> PyResult<f64> {

vm/src/obj/objiter.rs

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
* Various types to support iteration.
33
*/
44

5-
use super::objbool;
65
use crate::pyobject::{
76
PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
87
};
98
use crate::vm::VirtualMachine;
10-
// use super::objstr;
11-
use super::objtype; // Required for arg_check! to use isinstance
9+
10+
use super::objbool;
11+
use super::objbytearray::PyByteArray;
12+
use super::objbytes::PyBytes;
13+
use super::objrange::PyRange;
14+
use super::objtype;
1215

1316
/*
1417
* This helper function is called at multiple places. First, it is called
@@ -129,38 +132,43 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
129132
iterated_obj: ref iterated_obj_ref,
130133
} = iter.payload
131134
{
132-
match iterated_obj_ref.payload {
133-
PyObjectPayload::Sequence { ref elements } => {
134-
if position.get() < elements.borrow().len() {
135-
let obj_ref = elements.borrow()[position.get()].clone();
136-
position.set(position.get() + 1);
137-
Ok(obj_ref)
138-
} else {
139-
Err(new_stop_iteration(vm))
140-
}
135+
if let Some(range) = iterated_obj_ref.payload::<PyRange>() {
136+
if let Some(int) = range.get(position.get()) {
137+
position.set(position.get() + 1);
138+
Ok(vm.ctx.new_int(int))
139+
} else {
140+
Err(new_stop_iteration(vm))
141141
}
142-
143-
PyObjectPayload::Range { ref range } => {
144-
if let Some(int) = range.get(position.get()) {
145-
position.set(position.get() + 1);
146-
Ok(vm.ctx.new_int(int))
147-
} else {
148-
Err(new_stop_iteration(vm))
149-
}
142+
} else if let Some(bytes) = iterated_obj_ref.payload::<PyBytes>() {
143+
if position.get() < bytes.len() {
144+
let obj_ref = vm.ctx.new_int(bytes[position.get()]);
145+
position.set(position.get() + 1);
146+
Ok(obj_ref)
147+
} else {
148+
Err(new_stop_iteration(vm))
150149
}
151-
152-
PyObjectPayload::Bytes { ref value } => {
153-
if position.get() < value.borrow().len() {
154-
let obj_ref = vm.ctx.new_int(value.borrow()[position.get()]);
155-
position.set(position.get() + 1);
156-
Ok(obj_ref)
157-
} else {
158-
Err(new_stop_iteration(vm))
159-
}
150+
} else if let Some(bytes) = iterated_obj_ref.payload::<PyByteArray>() {
151+
if position.get() < bytes.value.borrow().len() {
152+
let obj_ref = vm.ctx.new_int(bytes.value.borrow()[position.get()]);
153+
position.set(position.get() + 1);
154+
Ok(obj_ref)
155+
} else {
156+
Err(new_stop_iteration(vm))
160157
}
161-
162-
_ => {
163-
panic!("NOT IMPL");
158+
} else {
159+
match iterated_obj_ref.payload {
160+
PyObjectPayload::Sequence { ref elements } => {
161+
if position.get() < elements.borrow().len() {
162+
let obj_ref = elements.borrow()[position.get()].clone();
163+
position.set(position.get() + 1);
164+
Ok(obj_ref)
165+
} else {
166+
Err(new_stop_iteration(vm))
167+
}
168+
}
169+
_ => {
170+
panic!("NOT IMPL");
171+
}
164172
}
165173
}
166174
} else {

0 commit comments

Comments
 (0)