Skip to content

Commit 0abd8b1

Browse files
committed
Fix pystructseq
1 parent 58a17f3 commit 0abd8b1

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

derive-impl/src/pystructseq.rs

+15-15
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,19 @@ fn field_names(input: &mut DeriveInput) -> Result<(Vec<Ident>, Vec<Ident>)> {
8080
}
8181

8282
pub(crate) fn impl_pystruct_sequence(mut input: DeriveInput) -> Result<TokenStream> {
83-
let (not_skipped_fields, _skipped_fields) = field_names(&mut input)?;
83+
let (not_skipped_fields, skipped_fields) = field_names(&mut input)?;
8484
let ty = &input.ident;
8585
let ret = quote! {
8686
impl ::rustpython_vm::types::PyStructSequence for #ty {
87-
const FIELD_NAMES: &'static [&'static str] = &[#(stringify!(#not_skipped_fields)),*];
87+
const REQUIRED_FIELD_NAMES: &'static [&'static str] = &[#(stringify!(#not_skipped_fields),)*];
88+
const OPTIONAL_FIELD_NAMES: &'static [&'static str] = &[#(stringify!(#skipped_fields),)*];
8889
fn into_tuple(self, vm: &::rustpython_vm::VirtualMachine) -> ::rustpython_vm::builtins::PyTuple {
89-
let items = vec![#(::rustpython_vm::convert::ToPyObject::to_pyobject(
90-
self.#not_skipped_fields,
91-
vm,
92-
)),*];
90+
let items = vec![
91+
#(::rustpython_vm::convert::ToPyObject::to_pyobject(
92+
self.#not_skipped_fields,
93+
vm,
94+
),)*
95+
];
9396
::rustpython_vm::builtins::PyTuple::new_unchecked(items.into_boxed_slice())
9497
}
9598
}
@@ -110,17 +113,14 @@ pub(crate) fn impl_pystruct_sequence_try_from_object(
110113
let ret = quote! {
111114
impl ::rustpython_vm::TryFromObject for #ty {
112115
fn try_from_object(vm: &::rustpython_vm::VirtualMachine, seq: ::rustpython_vm::PyObjectRef) -> ::rustpython_vm::PyResult<Self> {
113-
const LEN: usize = #ty::FIELD_NAMES.len();
114-
let seq = Self::try_elements_from::<LEN>(seq, vm)?;
115-
// TODO: this is possible to be written without iterator
116+
let seq = Self::try_elements_from(seq, vm)?;
116117
let mut iter = seq.into_iter();
117118
Ok(Self {
118-
#(
119-
#not_skipped_fields: iter.next().unwrap().clone().try_into_value(vm)?,
120-
)*
121-
#(
122-
#skipped_fields: vm.ctx.none(),
123-
)*
119+
#(#not_skipped_fields: iter.next().unwrap().clone().try_into_value(vm)?,)*
120+
#(#skipped_fields: match iter.next() {
121+
Some(v) => v.clone().try_into_value(vm)?,
122+
None => vm.ctx.none(),
123+
},)*
124124
})
125125
}
126126
}

vm/src/types/structseq.rs

+25-18
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use crate::{
22
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
3-
builtins::{PyTuple, PyTupleRef, PyType},
3+
builtins::{PyBaseExceptionRef, PyTuple, PyTupleRef, PyType},
44
class::{PyClassImpl, StaticType},
55
vm::Context,
66
};
77

88
#[pyclass]
99
pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
10-
const FIELD_NAMES: &'static [&'static str];
10+
const REQUIRED_FIELD_NAMES: &'static [&'static str];
11+
const OPTIONAL_FIELD_NAMES: &'static [&'static str];
1112

1213
fn into_tuple(self, vm: &VirtualMachine) -> PyTuple;
1314

@@ -17,10 +18,16 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
1718
.unwrap()
1819
}
1920

20-
fn try_elements_from<const FIELD_LEN: usize>(
21-
obj: PyObjectRef,
22-
vm: &VirtualMachine,
23-
) -> PyResult<[PyObjectRef; FIELD_LEN]> {
21+
fn try_elements_from(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Vec<PyObjectRef>> {
22+
#[cold]
23+
fn sequence_length_error(
24+
name: &str,
25+
len: usize,
26+
vm: &VirtualMachine,
27+
) -> PyBaseExceptionRef {
28+
vm.new_type_error(format!("{name} takes a sequence of length {len}"))
29+
}
30+
2431
let typ = Self::static_type();
2532
// if !obj.fast_isinstance(typ) {
2633
// return Err(vm.new_type_error(format!(
@@ -30,13 +37,13 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
3037
// )));
3138
// }
3239
let seq: Vec<PyObjectRef> = obj.try_into_value(vm)?;
33-
let seq: [PyObjectRef; FIELD_LEN] = seq.try_into().map_err(|_| {
34-
vm.new_type_error(format!(
35-
"{} takes a sequence of length {}",
36-
typ.name(),
37-
FIELD_LEN
38-
))
39-
})?;
40+
if seq.len() < Self::REQUIRED_FIELD_NAMES.len() {
41+
return Err(sequence_length_error(
42+
&typ.name(),
43+
Self::REQUIRED_FIELD_NAMES.len(),
44+
vm,
45+
));
46+
}
4047
Ok(seq)
4148
}
4249

@@ -49,14 +56,14 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
4956
let (body, suffix) = if let Some(_guard) =
5057
rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_object())
5158
{
52-
if Self::FIELD_NAMES.len() == 1 {
59+
if Self::REQUIRED_FIELD_NAMES.len() == 1 {
5360
let value = zelf.first().unwrap();
54-
let formatted = format_field((value, Self::FIELD_NAMES[0]))?;
61+
let formatted = format_field((value, Self::REQUIRED_FIELD_NAMES[0]))?;
5562
(formatted, ",")
5663
} else {
5764
let fields: PyResult<Vec<_>> = zelf
5865
.iter()
59-
.zip(Self::FIELD_NAMES.iter().copied())
66+
.zip(Self::REQUIRED_FIELD_NAMES.iter().copied())
6067
.map(format_field)
6168
.collect();
6269
(fields?.join(", "), "")
@@ -74,7 +81,7 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
7481

7582
#[extend_class]
7683
fn extend_pyclass(ctx: &Context, class: &'static Py<PyType>) {
77-
for (i, &name) in Self::FIELD_NAMES.iter().enumerate() {
84+
for (i, &name) in Self::REQUIRED_FIELD_NAMES.iter().enumerate() {
7885
// cast i to a u8 so there's less to store in the getter closure.
7986
// Hopefully there's not struct sequences with >=256 elements :P
8087
let i = i as u8;
@@ -90,7 +97,7 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
9097
class.set_attr(
9198
identifier!(ctx, __match_args__),
9299
ctx.new_tuple(
93-
Self::FIELD_NAMES
100+
Self::REQUIRED_FIELD_NAMES
94101
.iter()
95102
.map(|&name| ctx.new_str(name).into())
96103
.collect::<Vec<_>>(),

0 commit comments

Comments
 (0)