Skip to content

Commit 711b1a6

Browse files
authored
PyTypeFlags::{SEQUENCE,MAPPING} (#6109)
1 parent dae9584 commit 711b1a6

File tree

7 files changed

+95
-23
lines changed

7 files changed

+95
-23
lines changed

vm/src/builtins/dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ impl PyDict {
176176
AsMapping,
177177
Representable
178178
),
179-
flags(BASETYPE)
179+
flags(BASETYPE, MAPPING)
180180
)]
181181
impl PyDict {
182182
#[pyclassmethod]

vm/src/builtins/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ pub type PyListRef = PyRef<PyList>;
109109
AsSequence,
110110
Representable
111111
),
112-
flags(BASETYPE)
112+
flags(BASETYPE, SEQUENCE)
113113
)]
114114
impl PyList {
115115
#[pymethod]

vm/src/builtins/memory.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -538,17 +538,20 @@ impl Py<PyMemoryView> {
538538
}
539539
}
540540

541-
#[pyclass(with(
542-
Py,
543-
Hashable,
544-
Comparable,
545-
AsBuffer,
546-
AsMapping,
547-
AsSequence,
548-
Constructor,
549-
Iterable,
550-
Representable
551-
))]
541+
#[pyclass(
542+
with(
543+
Py,
544+
Hashable,
545+
Comparable,
546+
AsBuffer,
547+
AsMapping,
548+
AsSequence,
549+
Constructor,
550+
Iterable,
551+
Representable
552+
),
553+
flags(SEQUENCE)
554+
)]
552555
impl PyMemoryView {
553556
// TODO: Uncomment when Python adds __class_getitem__ to memoryview
554557
// #[pyclassmethod]

vm/src/builtins/range.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,18 @@ pub fn init(context: &Context) {
174174
PyRangeIterator::extend_class(context, context.types.range_iterator_type);
175175
}
176176

177-
#[pyclass(with(
178-
Py,
179-
AsMapping,
180-
AsSequence,
181-
Hashable,
182-
Comparable,
183-
Iterable,
184-
Representable
185-
))]
177+
#[pyclass(
178+
with(
179+
Py,
180+
AsMapping,
181+
AsSequence,
182+
Hashable,
183+
Comparable,
184+
Iterable,
185+
Representable
186+
),
187+
flags(SEQUENCE)
188+
)]
186189
impl PyRange {
187190
fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
188191
Self {

vm/src/builtins/tuple.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ impl<T> PyTuple<PyRef<T>> {
244244
}
245245

246246
#[pyclass(
247-
flags(BASETYPE),
247+
flags(BASETYPE, SEQUENCE),
248248
with(
249249
AsMapping,
250250
AsSequence,

vm/src/builtins/type.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use crate::{
3131
};
3232
use indexmap::{IndexMap, map::Entry};
3333
use itertools::Itertools;
34+
use num_traits::ToPrimitive;
3435
use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNull};
3536

3637
#[pyclass(module = false, name = "type", traverse = "manual")]
@@ -231,6 +232,58 @@ impl PyType {
231232
linearise_mro(mros)
232233
}
233234

235+
/// Inherit SEQUENCE and MAPPING flags from base class (CPython: inherit_patma_flags)
236+
fn inherit_patma_flags(slots: &mut PyTypeSlots, base: &PyRef<Self>) {
237+
const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate(
238+
PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(),
239+
);
240+
if !slots.flags.intersects(COLLECTION_FLAGS) {
241+
slots.flags |= base.slots.flags & COLLECTION_FLAGS;
242+
}
243+
}
244+
245+
/// Check for __abc_tpflags__ and set the appropriate flags
246+
/// This checks in attrs and all base classes for __abc_tpflags__
247+
fn check_abc_tpflags(
248+
slots: &mut PyTypeSlots,
249+
attrs: &PyAttributes,
250+
bases: &[PyRef<Self>],
251+
ctx: &Context,
252+
) {
253+
const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate(
254+
PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(),
255+
);
256+
257+
// Don't override if flags are already set
258+
if slots.flags.intersects(COLLECTION_FLAGS) {
259+
return;
260+
}
261+
262+
// First check in our own attributes
263+
let abc_tpflags_name = ctx.intern_str("__abc_tpflags__");
264+
if let Some(abc_tpflags_obj) = attrs.get(abc_tpflags_name) {
265+
if let Some(int_obj) = abc_tpflags_obj.downcast_ref::<crate::builtins::int::PyInt>() {
266+
let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0);
267+
let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64);
268+
slots.flags |= abc_flags & COLLECTION_FLAGS;
269+
return;
270+
}
271+
}
272+
273+
// Then check in base classes
274+
for base in bases {
275+
if let Some(abc_tpflags_obj) = base.find_name_in_mro(abc_tpflags_name) {
276+
if let Some(int_obj) = abc_tpflags_obj.downcast_ref::<crate::builtins::int::PyInt>()
277+
{
278+
let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0);
279+
let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64);
280+
slots.flags |= abc_flags & COLLECTION_FLAGS;
281+
return;
282+
}
283+
}
284+
}
285+
}
286+
234287
#[allow(clippy::too_many_arguments)]
235288
fn new_heap_inner(
236289
base: PyRef<Self>,
@@ -246,6 +299,13 @@ impl PyType {
246299
if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
247300
slots.flags |= PyTypeFlags::HAS_DICT
248301
}
302+
303+
// Inherit SEQUENCE and MAPPING flags from base class
304+
Self::inherit_patma_flags(&mut slots, &base);
305+
306+
// Check for __abc_tpflags__ from ABCMeta (for collections.abc.Sequence, Mapping, etc.)
307+
Self::check_abc_tpflags(&mut slots, &attrs, &bases, ctx);
308+
249309
if slots.basicsize == 0 {
250310
slots.basicsize = base.slots.basicsize;
251311
}
@@ -297,6 +357,10 @@ impl PyType {
297357
if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
298358
slots.flags |= PyTypeFlags::HAS_DICT
299359
}
360+
361+
// Inherit SEQUENCE and MAPPING flags from base class
362+
Self::inherit_patma_flags(&mut slots, &base);
363+
300364
if slots.basicsize == 0 {
301365
slots.basicsize = base.slots.basicsize;
302366
}

vm/src/types/slot.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ bitflags! {
123123
#[non_exhaustive]
124124
pub struct PyTypeFlags: u64 {
125125
const MANAGED_DICT = 1 << 4;
126+
const SEQUENCE = 1 << 5;
127+
const MAPPING = 1 << 6;
126128
const IMMUTABLETYPE = 1 << 8;
127129
const HEAPTYPE = 1 << 9;
128130
const BASETYPE = 1 << 10;

0 commit comments

Comments
 (0)