Skip to content

Commit a07857c

Browse files
committed
Impl PySequence Protocol
1 parent 4ede6b1 commit a07857c

File tree

4 files changed

+196
-18
lines changed

4 files changed

+196
-18
lines changed

vm/src/builtins/memory.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,24 @@ use crate::common::{
44
hash::PyHash,
55
lock::OnceCell,
66
rc::PyRc,
7+
static_cell,
78
};
89
use crate::{
910
bytesinner::bytes_to_hex,
1011
function::{FuncArgs, IntoPyObject, OptionalArg},
11-
protocol::{BufferInternal, BufferOptions, PyBuffer, PyMappingMethods},
12+
protocol::{BufferInternal, BufferOptions, PyBuffer, PyMappingMethods, PySequenceMethods},
1213
sliceable::{wrap_index, SaturatedSlice, SequenceIndex},
1314
stdlib::pystruct::FormatSpec,
14-
types::{AsBuffer, AsMapping, Comparable, Constructor, Hashable, PyComparisonOp},
15+
types::{AsBuffer, AsMapping, AsSequence, Comparable, Constructor, Hashable, PyComparisonOp},
1516
utils::Either,
1617
IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef,
1718
PyResult, PyValue, TryFromBorrowedObject, TryFromObject, TypeProtocol, VirtualMachine,
1819
};
1920
use crossbeam_utils::atomic::AtomicCell;
2021
use itertools::Itertools;
2122
use num_traits::ToPrimitive;
22-
use std::fmt::Debug;
2323
use std::ops::Deref;
24+
use std::{borrow::Cow, fmt::Debug};
2425

2526
#[derive(FromArgs)]
2627
pub struct PyMemoryViewNewArgs {
@@ -771,6 +772,22 @@ impl AsMapping for PyMemoryView {
771772
}
772773
}
773774

775+
impl AsSequence for PyMemoryView {
776+
fn as_sequence(_zelf: &PyRef<Self>, _vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> {
777+
static_cell! {
778+
static METHODS: PySequenceMethods;
779+
}
780+
Cow::Borrowed(METHODS.get_or_init(|| PySequenceMethods {
781+
length: Some(|zelf, vm| zelf.payload::<Self>().unwrap().len(vm)),
782+
item: Some(|zelf, i, vm| {
783+
let zelf = zelf.clone().downcast::<Self>().unwrap();
784+
Self::getitem_by_idx(zelf, i, vm)
785+
}),
786+
..Default::default()
787+
}))
788+
}
789+
}
790+
774791
impl Comparable for PyMemoryView {
775792
fn cmp(
776793
zelf: &PyRef<Self>,

vm/src/protocol/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
mod buffer;
22
mod iter;
33
mod mapping;
4+
mod sequence;
45

56
pub use buffer::{BufferInternal, BufferOptions, BufferResizeGuard, PyBuffer};
67
pub use iter::{PyIter, PyIterIter, PyIterReturn};
78
pub use mapping::{PyMapping, PyMappingMethods};
9+
pub use sequence::{PySequence, PySequenceMethods};

vm/src/protocol/sequence.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use std::borrow::{Borrow, Cow};
2+
3+
use crate::{IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine};
4+
5+
// Sequence Protocol
6+
// https://docs.python.org/3/c-api/sequence.html
7+
8+
#[derive(Default, Clone)]
9+
pub struct PySequenceMethods {
10+
pub length: Option<fn(&PyObjectRef, &VirtualMachine) -> PyResult<usize>>,
11+
pub concat: Option<fn(&PyObjectRef, &PyObjectRef, &VirtualMachine) -> PyResult<PyObjectRef>>,
12+
pub repeat: Option<fn(&PyObjectRef, usize, &VirtualMachine) -> PyResult<PyObjectRef>>,
13+
pub inplace_concat:
14+
Option<fn(PyObjectRef, &PyObjectRef, &VirtualMachine) -> PyResult<PyObjectRef>>,
15+
pub inplace_repeat: Option<fn(PyObjectRef, usize, &VirtualMachine) -> PyResult<PyObjectRef>>,
16+
pub item: Option<fn(&PyObjectRef, isize, &VirtualMachine) -> PyResult<PyObjectRef>>,
17+
pub ass_item:
18+
Option<fn(PyObjectRef, isize, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
19+
pub contains: Option<fn(&PyObjectRef, &PyObjectRef, &VirtualMachine) -> PyResult<bool>>,
20+
}
21+
22+
pub struct PySequence(PyObjectRef, Cow<'static, PySequenceMethods>);
23+
24+
impl PySequence {
25+
pub fn check(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
26+
let cls = obj.class();
27+
if cls.is(&vm.ctx.types.dict_type) {
28+
return false;
29+
}
30+
if let Some(f) = cls.mro_find_map(|x| x.slots.as_sequence.load()) {
31+
return f(obj, vm).item.is_some();
32+
}
33+
false
34+
}
35+
36+
pub fn from_object(vm: &VirtualMachine, obj: PyObjectRef) -> Option<Self> {
37+
let cls = obj.class();
38+
if cls.is(&vm.ctx.types.dict_type) {
39+
return None;
40+
}
41+
let f = cls.mro_find_map(|x| x.slots.as_sequence.load())?;
42+
drop(cls);
43+
let methods = f(&obj, vm);
44+
if methods.item.is_some() {
45+
Some(Self(obj, methods))
46+
} else {
47+
None
48+
}
49+
}
50+
51+
pub fn methods(&self) -> &PySequenceMethods {
52+
self.1.borrow()
53+
}
54+
}

vm/src/types/slot.rs

Lines changed: 120 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use crate::common::{hash::PyHash, lock::PyRwLock};
22
use crate::{
33
builtins::{PyInt, PyStrRef, PyType, PyTypeRef},
4-
function::{FromArgs, FuncArgs, IntoPyResult, OptionalArg},
5-
protocol::{PyBuffer, PyIterReturn, PyMappingMethods},
4+
function::{FromArgs, FuncArgs, IntoPyObject, IntoPyResult, OptionalArg},
5+
protocol::{PyBuffer, PyIterReturn, PyMappingMethods, PySequence, PySequenceMethods},
66
utils::Either,
7-
IdProtocol, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
8-
VirtualMachine,
7+
IdProtocol, PyArithmeticValue, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue,
8+
TypeProtocol, VirtualMachine,
99
};
1010
use crossbeam_utils::atomic::AtomicCell;
1111
use num_traits::ToPrimitive;
12+
use std::borrow::Cow;
1213
use std::cmp::Ordering;
1314

1415
// The corresponding field in CPython is `tp_` prefixed.
@@ -22,7 +23,7 @@ pub struct PyTypeSlots {
2223

2324
// Method suites for standard classes
2425
// tp_as_number
25-
// tp_as_sequence
26+
pub as_sequence: AtomicCell<Option<AsSequenceFunc>>,
2627
pub as_mapping: AtomicCell<Option<AsMappingFunc>>,
2728

2829
// More standard operations (here for binary compatibility)
@@ -149,17 +150,20 @@ pub(crate) type DescrSetFunc =
149150
fn(PyObjectRef, PyObjectRef, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>;
150151
pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult;
151152
pub(crate) type DelFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult<()>;
153+
pub(crate) type AsSequenceFunc =
154+
fn(&PyObjectRef, &VirtualMachine) -> Cow<'static, PySequenceMethods>;
155+
156+
macro_rules! then_some_closure {
157+
($cond:expr, $closure:expr) => {
158+
if $cond {
159+
Some($closure)
160+
} else {
161+
None
162+
}
163+
};
164+
}
152165

153166
fn as_mapping_wrapper(zelf: &PyObjectRef, _vm: &VirtualMachine) -> PyMappingMethods {
154-
macro_rules! then_some_closure {
155-
($cond:expr, $closure:expr) => {
156-
if $cond {
157-
Some($closure)
158-
} else {
159-
None
160-
}
161-
};
162-
}
163167
PyMappingMethods {
164168
length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| {
165169
vm.call_special_method(zelf, "__len__", ()).map(|obj| {
@@ -192,6 +196,90 @@ fn as_mapping_wrapper(zelf: &PyObjectRef, _vm: &VirtualMachine) -> PyMappingMeth
192196
}
193197
}
194198

199+
fn as_sequence_wrapper(
200+
zelf: &PyObjectRef,
201+
_vm: &VirtualMachine,
202+
) -> Cow<'static, PySequenceMethods> {
203+
Cow::Owned(PySequenceMethods {
204+
length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| {
205+
vm.obj_len_opt(zelf).unwrap()
206+
}),
207+
concat: then_some_closure!(zelf.has_class_attr("__add__"), |zelf, other, vm| {
208+
if PySequence::check(zelf, vm) && PySequence::check(other, vm) {
209+
let ret = vm.call_special_method(zelf.clone(), "__add__", (other.clone(),))?;
210+
if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret)
211+
{
212+
return Ok(obj);
213+
}
214+
}
215+
Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf)))
216+
}),
217+
repeat: then_some_closure!(zelf.has_class_attr("__mul__"), |zelf, n, vm| {
218+
if PySequence::check(zelf, vm) {
219+
let ret =
220+
vm.call_special_method(zelf.clone(), "__mul__", (n.into_pyobject(vm),))?;
221+
if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret)
222+
{
223+
return Ok(obj);
224+
}
225+
}
226+
Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf)))
227+
}),
228+
inplace_concat: then_some_closure!(
229+
zelf.has_class_attr("__iadd__") || zelf.has_class_attr("__add__"),
230+
|zelf, other, vm| {
231+
if PySequence::check(&zelf, vm) && PySequence::check(other, vm) {
232+
if let Ok(f) = vm.get_special_method(zelf.clone(), "__iadd__")? {
233+
let ret = f.invoke((other.clone(),), vm)?;
234+
if let PyArithmeticValue::Implemented(obj) =
235+
PyArithmeticValue::from_object(vm, ret)
236+
{
237+
return Ok(obj);
238+
}
239+
}
240+
if let Ok(f) = vm.get_special_method(zelf.clone(), "__add__")? {
241+
let ret = f.invoke((other.clone(),), vm)?;
242+
if let PyArithmeticValue::Implemented(obj) =
243+
PyArithmeticValue::from_object(vm, ret)
244+
{
245+
return Ok(obj);
246+
}
247+
}
248+
}
249+
Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf)))
250+
}
251+
),
252+
inplace_repeat: then_some_closure!(
253+
zelf.has_class_attr("__imul__") || zelf.has_class_attr("__mul__"),
254+
|zelf, n, vm| {
255+
if PySequence::check(&zelf, vm) {
256+
if let Ok(f) = vm.get_special_method(zelf.clone(), "__imul__")? {
257+
let ret = f.invoke((n.into_pyobject(vm),), vm)?;
258+
if let PyArithmeticValue::Implemented(obj) =
259+
PyArithmeticValue::from_object(vm, ret)
260+
{
261+
return Ok(obj);
262+
}
263+
}
264+
if let Ok(f) = vm.get_special_method(zelf.clone(), "__mul__")? {
265+
let ret = f.invoke((n.into_pyobject(vm),), vm)?;
266+
if let PyArithmeticValue::Implemented(obj) =
267+
PyArithmeticValue::from_object(vm, ret)
268+
{
269+
return Ok(obj);
270+
}
271+
}
272+
}
273+
Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf)))
274+
}
275+
),
276+
item: None,
277+
ass_item: None,
278+
// TODO: IterSearch
279+
contains: None,
280+
})
281+
}
282+
195283
fn hash_wrapper(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult<PyHash> {
196284
let hash_obj = vm.call_special_method(zelf.clone(), "__hash__", ())?;
197285
match hash_obj.payload_if_subclass::<PyInt>(vm) {
@@ -291,7 +379,10 @@ impl PyType {
291379
match name {
292380
"__len__" | "__getitem__" | "__setitem__" | "__delitem__" => {
293381
update_slot!(as_mapping, as_mapping_wrapper);
294-
// TODO: need to update sequence protocol too
382+
update_slot!(as_sequence, as_sequence_wrapper);
383+
}
384+
"__add__" | "__iadd__" | "__mul__" | "__imul__" => {
385+
update_slot!(as_sequence, as_sequence_wrapper);
295386
}
296387
"__hash__" => {
297388
update_slot!(hash, hash_wrapper);
@@ -804,6 +895,20 @@ pub trait AsMapping: PyValue {
804895
) -> PyResult<()>;
805896
}
806897

898+
#[pyimpl]
899+
pub trait AsSequence: PyValue {
900+
#[inline]
901+
#[pyslot]
902+
fn slot_as_sequence(
903+
zelf: &PyObjectRef,
904+
vm: &VirtualMachine,
905+
) -> Cow<'static, PySequenceMethods> {
906+
let zelf = unsafe { zelf.downcast_unchecked_ref::<Self>() };
907+
Self::as_sequence(zelf, vm)
908+
}
909+
fn as_sequence(zelf: &PyRef<Self>, vm: &VirtualMachine) -> Cow<'static, PySequenceMethods>;
910+
}
911+
807912
#[pyimpl]
808913
pub trait Iterable: PyValue {
809914
#[pyslot]

0 commit comments

Comments
 (0)