Skip to content

Commit 657d025

Browse files
committed
Change syntax of attributes in FromArgs proc macro
1 parent 3ca387b commit 657d025

File tree

4 files changed

+164
-109
lines changed

4 files changed

+164
-109
lines changed

derive/src/lib.rs

Lines changed: 144 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,177 @@ extern crate proc_macro;
33
use proc_macro::TokenStream;
44
use proc_macro2::TokenStream as TokenStream2;
55
use quote::quote;
6-
use syn::{Data, DeriveInput, Field, Fields};
6+
use syn::{Attribute, Data, DeriveInput, Expr, Field, Fields, Ident, Lit, Meta, NestedMeta};
77

8-
#[proc_macro_derive(FromArgs, attributes(positional, keyword))]
8+
#[proc_macro_derive(FromArgs, attributes(pyarg))]
99
pub fn derive_from_args(input: TokenStream) -> TokenStream {
1010
let ast: DeriveInput = syn::parse(input).unwrap();
1111

1212
let gen = impl_from_args(&ast);
1313
gen.to_string().parse().unwrap()
1414
}
1515

16-
enum ArgType {
16+
enum ArgKind {
1717
Positional,
1818
PositionalKeyword,
1919
Keyword,
2020
}
2121

22-
fn generate_field(field: &Field) -> TokenStream2 {
23-
let arg_type = if let Some(attr) = field.attrs.first() {
24-
if attr.path.is_ident("positional") {
25-
ArgType::Positional
26-
} else if attr.path.is_ident("keyword") {
27-
ArgType::Keyword
22+
impl ArgKind {
23+
fn from_ident(ident: &Ident) -> ArgKind {
24+
if ident == "positional" {
25+
ArgKind::Positional
26+
} else if ident == "positional_keyword" {
27+
ArgKind::PositionalKeyword
28+
} else if ident == "keyword" {
29+
ArgKind::Keyword
2830
} else {
2931
panic!("Unrecognised attribute")
3032
}
33+
}
34+
}
35+
36+
struct ArgAttribute {
37+
kind: ArgKind,
38+
default: Option<Expr>,
39+
optional: bool,
40+
}
41+
42+
impl ArgAttribute {
43+
fn from_attribute(attr: &Attribute) -> Option<ArgAttribute> {
44+
if !attr.path.is_ident("pyarg") {
45+
return None;
46+
}
47+
48+
match attr.parse_meta().unwrap() {
49+
Meta::List(list) => {
50+
let mut iter = list.nested.iter();
51+
let first_arg = iter.next().expect("at least one argument in pyarg list");
52+
let kind = match first_arg {
53+
NestedMeta::Meta(Meta::Word(ident)) => ArgKind::from_ident(ident),
54+
_ => panic!("Bad syntax for first pyarg attribute argument"),
55+
};
56+
57+
let mut attribute = ArgAttribute {
58+
kind,
59+
default: None,
60+
optional: false,
61+
};
62+
63+
while let Some(arg) = iter.next() {
64+
attribute.parse_argument(arg);
65+
}
66+
67+
assert!(
68+
attribute.default.is_none() || !attribute.optional,
69+
"Can't set both a default value and optional"
70+
);
71+
72+
Some(attribute)
73+
}
74+
_ => panic!("Bad syntax for pyarg attribute"),
75+
}
76+
}
77+
78+
fn parse_argument(&mut self, arg: &NestedMeta) {
79+
match arg {
80+
NestedMeta::Meta(Meta::Word(ident)) => {
81+
if ident == "default" {
82+
assert!(self.default.is_none(), "Default already set");
83+
let expr = syn::parse_str::<Expr>("Default::default()").unwrap();
84+
self.default = Some(expr);
85+
} else if ident == "optional" {
86+
self.optional = true;
87+
} else {
88+
panic!("Unrecognised pyarg attribute '{}'", ident);
89+
}
90+
}
91+
NestedMeta::Meta(Meta::NameValue(name_value)) => {
92+
if name_value.ident == "default" {
93+
assert!(self.default.is_none(), "Default already set");
94+
95+
match name_value.lit {
96+
Lit::Str(ref val) => {
97+
let expr = val
98+
.parse::<Expr>()
99+
.expect("a valid expression for default argument");
100+
self.default = Some(expr);
101+
}
102+
_ => panic!("Expected string value for default argument"),
103+
}
104+
} else if name_value.ident == "optional" {
105+
match name_value.lit {
106+
Lit::Bool(ref val) => {
107+
self.optional = val.value;
108+
}
109+
_ => panic!("Expected boolean value for optional argument"),
110+
}
111+
} else {
112+
panic!("Unrecognised pyarg attribute '{}'", name_value.ident);
113+
}
114+
}
115+
_ => panic!("Bad syntax for first pyarg attribute argument"),
116+
};
117+
}
118+
}
119+
120+
fn generate_field(field: &Field) -> TokenStream2 {
121+
let mut pyarg_attrs = field
122+
.attrs
123+
.iter()
124+
.filter_map(ArgAttribute::from_attribute)
125+
.collect::<Vec<_>>();
126+
let attr = if pyarg_attrs.is_empty() {
127+
ArgAttribute {
128+
kind: ArgKind::PositionalKeyword,
129+
default: None,
130+
optional: false,
131+
}
132+
} else if pyarg_attrs.len() == 1 {
133+
pyarg_attrs.remove(0)
31134
} else {
32-
ArgType::PositionalKeyword
135+
panic!(
136+
"Multiple pyarg attributes on field '{}'",
137+
field.ident.as_ref().unwrap()
138+
);
33139
};
34140

35141
let name = &field.ident;
36-
match arg_type {
37-
ArgType::Positional => {
142+
let middle = quote! {
143+
.map(|x| crate::pyobject::TryFromObject::try_from_object(vm, x)).transpose()?
144+
};
145+
let ending = if let Some(default) = attr.default {
146+
quote! {
147+
.unwrap_or_else(|| #default)
148+
}
149+
} else {
150+
let err = match attr.kind {
151+
ArgKind::Positional | ArgKind::PositionalKeyword => {
152+
quote!(crate::function::ArgumentError::TooFewArgs)
153+
}
154+
ArgKind::Keyword => quote!(crate::function::ArgumentError::RequiredKeywordArgument(
155+
stringify!(#name)
156+
)),
157+
};
158+
quote! {
159+
.ok_or_else(|| #err)?
160+
}
161+
};
162+
163+
match attr.kind {
164+
ArgKind::Positional => {
38165
quote! {
39-
#name: args.take_positional(vm)?,
166+
#name: args.take_positional()#middle#ending,
40167
}
41168
}
42-
ArgType::PositionalKeyword => {
169+
ArgKind::PositionalKeyword => {
43170
quote! {
44-
#name: args.take_positional_keyword(vm, stringify!(#name))?,
171+
#name: args.take_positional_keyword(stringify!(#name))#middle#ending,
45172
}
46173
}
47-
ArgType::Keyword => {
174+
ArgKind::Keyword => {
48175
quote! {
49-
#name: args.take_keyword(vm, stringify!(#name))?,
176+
#name: args.take_keyword(stringify!(#name))#middle#ending,
50177
}
51178
}
52179
}

vm/src/builtins.rs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -559,22 +559,14 @@ fn builtin_pow(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
559559
}
560560
}
561561

562-
// Idea: Should we have a 'default' attribute, so we don't have to use OptionalArg's in this case
563-
//#[derive(Debug, FromArgs)]
564-
//pub struct PrintOptions {
565-
// #[keyword] #[default(None)] sep: Option<PyStringRef>,
566-
// #[keyword] #[default(None)] end: Option<PyStringRef>,
567-
// #[keyword] #[default(flush)] flush: bool,
568-
//}
569-
570562
#[derive(Debug, FromArgs)]
571563
pub struct PrintOptions {
572-
#[keyword]
573-
sep: OptionalArg<Option<PyStringRef>>,
574-
#[keyword]
575-
end: OptionalArg<Option<PyStringRef>>,
576-
#[keyword]
577-
flush: OptionalArg<bool>,
564+
#[pyarg(keyword, default = "None")]
565+
sep: Option<PyStringRef>,
566+
#[pyarg(keyword, default = "None")]
567+
end: Option<PyStringRef>,
568+
#[pyarg(keyword, default = "false")]
569+
flush: bool,
578570
}
579571

580572
pub fn builtin_print(objects: Args, options: PrintOptions, vm: &VirtualMachine) -> PyResult<()> {
@@ -584,7 +576,7 @@ pub fn builtin_print(objects: Args, options: PrintOptions, vm: &VirtualMachine)
584576
for object in objects {
585577
if first {
586578
first = false;
587-
} else if let OptionalArg::Present(Some(ref sep)) = options.sep {
579+
} else if let Some(ref sep) = options.sep {
588580
write!(stdout_lock, "{}", sep.value).unwrap();
589581
} else {
590582
write!(stdout_lock, " ").unwrap();
@@ -593,13 +585,13 @@ pub fn builtin_print(objects: Args, options: PrintOptions, vm: &VirtualMachine)
593585
write!(stdout_lock, "{}", s).unwrap();
594586
}
595587

596-
if let OptionalArg::Present(Some(end)) = options.end {
588+
if let Some(end) = options.end {
597589
write!(stdout_lock, "{}", end.value).unwrap();
598590
} else {
599591
writeln!(stdout_lock).unwrap();
600592
}
601593

602-
if options.flush.into_option().unwrap_or(false) {
594+
if options.flush {
603595
stdout_lock.flush().unwrap();
604596
}
605597

vm/src/function.rs

Lines changed: 9 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,19 @@ impl PyFuncArgs {
109109
}
110110
}
111111

112-
pub fn next_positional(&mut self) -> Option<PyObjectRef> {
112+
pub fn take_positional(&mut self) -> Option<PyObjectRef> {
113113
if self.args.is_empty() {
114114
None
115115
} else {
116116
Some(self.args.remove(0))
117117
}
118118
}
119119

120-
fn extract_keyword(&mut self, name: &str) -> Option<PyObjectRef> {
120+
pub fn take_positional_keyword(&mut self, name: &str) -> Option<PyObjectRef> {
121+
self.take_positional().or_else(|| self.take_keyword(name))
122+
}
123+
124+
pub fn take_keyword(&mut self, name: &str) -> Option<PyObjectRef> {
121125
// TODO: change kwarg representation so this scan isn't necessary
122126
if let Some(index) = self
123127
.kwargs
@@ -130,49 +134,6 @@ impl PyFuncArgs {
130134
}
131135
}
132136

133-
pub fn take_positional<H: ArgHandler>(
134-
&mut self,
135-
vm: &VirtualMachine,
136-
) -> Result<H, ArgumentError> {
137-
if let Some(arg) = self.next_positional() {
138-
H::from_arg(vm, arg).map_err(|err| ArgumentError::Exception(err))
139-
} else if let Some(default) = H::DEFAULT {
140-
Ok(default)
141-
} else {
142-
Err(ArgumentError::TooFewArgs)
143-
}
144-
}
145-
146-
pub fn take_positional_keyword<H: ArgHandler>(
147-
&mut self,
148-
vm: &VirtualMachine,
149-
name: &str,
150-
) -> Result<H, ArgumentError> {
151-
if let Some(arg) = self.next_positional() {
152-
H::from_arg(vm, arg).map_err(|err| ArgumentError::Exception(err))
153-
} else if let Some(arg) = self.extract_keyword(name) {
154-
H::from_arg(vm, arg).map_err(|err| ArgumentError::Exception(err))
155-
} else if let Some(default) = H::DEFAULT {
156-
Ok(default)
157-
} else {
158-
Err(ArgumentError::TooFewArgs)
159-
}
160-
}
161-
162-
pub fn take_keyword<H: ArgHandler>(
163-
&mut self,
164-
vm: &VirtualMachine,
165-
name: &str,
166-
) -> Result<H, ArgumentError> {
167-
if let Some(arg) = self.extract_keyword(name) {
168-
H::from_arg(vm, arg).map_err(|err| ArgumentError::Exception(err))
169-
} else if let Some(default) = H::DEFAULT {
170-
Ok(default)
171-
} else {
172-
Err(ArgumentError::RequiredKeywordArgument(name.to_string()))
173-
}
174-
}
175-
176137
pub fn remaining_keyword<'a>(&'a mut self) -> impl Iterator<Item = (String, PyObjectRef)> + 'a {
177138
self.kwargs.drain(..)
178139
}
@@ -265,32 +226,6 @@ pub trait FromArgs: Sized {
265226
fn from_args(vm: &VirtualMachine, args: &mut PyFuncArgs) -> Result<Self, ArgumentError>;
266227
}
267228

268-
/// Handling behaviour when the argument is and isn't present, used to implement OptionalArg.
269-
pub trait ArgHandler: Sized {
270-
/// Default value that will be used if the argument isn't present, or None in which case the a
271-
/// appropriate error is returned.
272-
const DEFAULT: Option<Self>;
273-
274-
/// Converts the arg argument when it is present
275-
fn from_arg(vm: &VirtualMachine, object: PyObjectRef) -> PyResult<Self>;
276-
}
277-
278-
impl<T: TryFromObject> ArgHandler for OptionalArg<T> {
279-
const DEFAULT: Option<Self> = Some(Missing);
280-
281-
fn from_arg(vm: &VirtualMachine, object: PyObjectRef) -> PyResult<Self> {
282-
T::try_from_object(vm, object).map(|x| Present(x))
283-
}
284-
}
285-
286-
impl<T: TryFromObject> ArgHandler for T {
287-
const DEFAULT: Option<Self> = None;
288-
289-
fn from_arg(vm: &VirtualMachine, object: PyObjectRef) -> PyResult<Self> {
290-
T::try_from_object(vm, object)
291-
}
292-
}
293-
294229
/// A map of keyword arguments to their values.
295230
///
296231
/// A built-in function with a `KwArgs` parameter is analagous to a Python
@@ -345,7 +280,7 @@ where
345280
{
346281
fn from_args(vm: &VirtualMachine, args: &mut PyFuncArgs) -> Result<Self, ArgumentError> {
347282
let mut varargs = Vec::new();
348-
while let Some(value) = args.next_positional() {
283+
while let Some(value) = args.take_positional() {
349284
varargs.push(T::try_from_object(vm, value)?);
350285
}
351286
Ok(Args(varargs))
@@ -370,7 +305,7 @@ where
370305
}
371306

372307
fn from_args(vm: &VirtualMachine, args: &mut PyFuncArgs) -> Result<Self, ArgumentError> {
373-
if let Some(value) = args.next_positional() {
308+
if let Some(value) = args.take_positional() {
374309
Ok(T::try_from_object(vm, value)?)
375310
} else {
376311
Err(ArgumentError::TooFewArgs)
@@ -414,7 +349,7 @@ where
414349
}
415350

416351
fn from_args(vm: &VirtualMachine, args: &mut PyFuncArgs) -> Result<Self, ArgumentError> {
417-
if let Some(value) = args.next_positional() {
352+
if let Some(value) = args.take_positional() {
418353
Ok(Present(T::try_from_object(vm, value)?))
419354
} else {
420355
Ok(Missing)

vm/src/obj/objint.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,9 @@ impl PyIntRef {
382382

383383
#[derive(FromArgs)]
384384
struct IntOptions {
385-
#[positional]
385+
#[pyarg(positional, optional = true)]
386386
val_options: OptionalArg<PyObjectRef>,
387+
#[pyarg(positional_keyword, optional = true)]
387388
base: OptionalArg<u32>,
388389
}
389390

0 commit comments

Comments
 (0)