Skip to content

poc for py:: prefix #3645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ast/asdl_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def gen_classdef(self, name, fields, attrs, depth, base="AstNode"):
self.emit(f"struct {structname};", depth)
self.emit("#[pyimpl(flags(HAS_DICT, BASETYPE))]", depth)
self.emit(f"impl {structname} {{", depth)
self.emit(f"#[extend_class]", depth + 1)
self.emit(f"#[py::extend_class]", depth + 1)
self.emit("fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {", depth + 1)
fields = ",".join(f"ctx.new_str(ascii!({json.dumps(f.name)})).into()" for f in fields)
self.emit(f'class.set_attr(identifier!(ctx, _fields), ctx.new_list(vec![{fields}]).into());', depth + 2)
Expand Down
16 changes: 5 additions & 11 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ pub fn pyclass(
) -> proc_macro::TokenStream {
let attr = parse_macro_input!(attr as AttributeArgs);
let item = parse_macro_input!(item as Item);
result_to_tokens(pyclass::impl_pyclass(attr, item))
if matches!(item, syn::Item::Impl(_) | syn::Item::Trait(_)) {
result_to_tokens(pyclass::impl_pyimpl(attr, item))
} else {
result_to_tokens(pyclass::impl_pyclass(attr, item))
}
}

/// This macro serves a goal of generating multiple
Expand Down Expand Up @@ -76,16 +80,6 @@ pub fn pyexception(
result_to_tokens(pyclass::impl_pyexception(attr, item))
}

#[proc_macro_attribute]
pub fn pyimpl(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let attr = parse_macro_input!(attr as AttributeArgs);
let item = parse_macro_input!(item as Item);
result_to_tokens(pyclass::impl_pyimpl(attr, item))
}

#[proc_macro_attribute]
pub fn pymodule(
attr: proc_macro::TokenStream,
Expand Down
41 changes: 18 additions & 23 deletions derive/src/pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::Diagnostic;
use crate::util::{
path_eq, pyclass_ident_and_attrs, text_signature, ClassItemMeta, ContentItem, ContentItemInner,
ErrorVec, ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, ALL_ALLOWED_NAMES,
path_eq, pyclass_ident_and_attrs, text_signature, AttributeExt, ClassItemMeta, ContentItem,
ContentItemInner, ErrorVec, ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta,
ALL_ALLOWED_NAMES,
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
Expand Down Expand Up @@ -35,7 +36,7 @@ impl std::fmt::Display for AttrName {
Self::GetSet => "pyproperty",
Self::Slot => "pyslot",
Self::Attr => "pyattr",
Self::ExtendClass => "extend_class",
Self::ExtendClass => "pyextend_class",
};
s.fmt(f)
}
Expand All @@ -51,7 +52,7 @@ impl FromStr for AttrName {
"pyproperty" => Self::GetSet,
"pyslot" => Self::Slot,
"pyattr" => Self::Attr,
"extend_class" => Self::ExtendClass,
"pyextend_class" => Self::ExtendClass,
s => {
return Err(s.to_owned());
}
Expand Down Expand Up @@ -373,7 +374,7 @@ pub(crate) fn impl_define_exception(exc_def: PyExceptionDef) -> Result<TokenStre
}
}

#[pyimpl(flags(BASETYPE, HAS_DICT))]
#[pyclass(flags(BASETYPE, HAS_DICT))]
impl #class_name {
#[pyslot]
pub(crate) fn slot_new(
Expand Down Expand Up @@ -418,7 +419,7 @@ struct AttributeItem {
inner: ContentItemInner<AttrName>,
}

/// #[extend_class]
/// #[py::extend_class]
struct ExtendClassItem {
inner: ContentItemInner<AttrName>,
}
Expand Down Expand Up @@ -954,7 +955,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
let path = match meta {
NestedMeta::Meta(Meta::Path(path)) => path,
meta => {
bail_span!(meta, "#[pyimpl(with(...))] arguments should be paths")
bail_span!(meta, "#[pyclass(with(...))] arguments should be paths")
}
};
let (extend_class, extend_slots) = if path_eq(&path, "PyRef") {
Expand Down Expand Up @@ -988,12 +989,12 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
} else {
bail_span!(
path,
"#[pyimpl(flags(...))] arguments should be ident"
"#[pyclass(flags(...))] arguments should be ident"
)
}
}
meta => {
bail_span!(meta, "#[pyimpl(flags(...))] arguments should be ident")
bail_span!(meta, "#[pyclass(flags(...))] arguments should be ident")
}
}
}
Expand Down Expand Up @@ -1061,39 +1062,33 @@ where
while let Some((_, attr)) = iter.peek() {
// take all cfgs but no py items
let attr = *attr;
let attr_name = if let Some(ident) = attr.get_ident() {
ident.to_string()
} else {
continue;
};
if attr_name == "cfg" {
let attr_path = attr.path_string();

if attr_path == "cfg" {
cfgs.push(attr.clone());
} else if ALL_ALLOWED_NAMES.contains(&attr_name.as_str()) {
} else if attr_path.starts_with("py") {
break;
}
iter.next();
}

for (i, attr) in iter {
// take py items but no cfgs
let attr_name = if let Some(ident) = attr.get_ident() {
ident.to_string()
} else {
continue;
};
if attr_name == "cfg" {
let attr_path = attr.path_string();
if attr_path == "cfg" {
return Err(syn::Error::new_spanned(
attr,
"#[py*] items must be placed under `cfgs`",
));
}
let attr_name = attr_path.replace("::", "");
let attr_name = match AttrName::from_str(attr_name.as_str()) {
Ok(name) => name,
Err(wrong_name) => {
if ALL_ALLOWED_NAMES.contains(&attr_name.as_str()) {
return Err(syn::Error::new_spanned(
attr,
format!("#[pyimpl] doesn't accept #[{}]", wrong_name),
format!("#[pyclass] doesn't accept #[{}]", wrong_name),
));
} else {
continue;
Expand Down
23 changes: 11 additions & 12 deletions derive/src/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ pub fn impl_pymodule(attr: AttributeArgs, module_item: Item) -> Result<TokenStre

// collect to context
for item in items.iter_mut() {
if matches!(item, Item::Impl(_) | Item::Trait(_)) {
// #[pyimpl] cases
continue;
}
let r = item.try_split_attr_mut(|attrs, item| {
let (pyitems, cfgs) = attrs_to_module_items(attrs, module_item_new)?;
for pyitem in pyitems.iter().rev() {
Expand Down Expand Up @@ -170,13 +174,11 @@ where
while let Some((_, attr)) = iter.peek() {
// take all cfgs but no py items
let attr = *attr;
if let Some(ident) = attr.get_ident() {
let attr_name = ident.to_string();
if attr_name == "cfg" {
cfgs.push(attr.clone());
} else if ALL_ALLOWED_NAMES.contains(&attr_name.as_str()) {
break;
}
let attr_path = attr.path_string();
if attr_path == "cfg" {
cfgs.push(attr.clone());
} else if attr_path.starts_with("py") {
break;
}
iter.next();
}
Expand All @@ -185,17 +187,14 @@ where
let mut pyattrs = Vec::new();
for (i, attr) in iter {
// take py items but no cfgs
let attr_name = if let Some(ident) = attr.get_ident() {
ident.to_string()
} else {
continue;
};
let attr_name = attr.path_string();
if attr_name == "cfg" {
return Err(syn::Error::new_spanned(
attr,
"#[py*] items must be placed under `cfgs`",
));
}
let attr_name = attr_name.replace("::", "");

let attr_name = match AttrName::from_str(attr_name.as_str()) {
Ok(name) => name,
Expand Down
29 changes: 23 additions & 6 deletions derive/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ struct NurseryItem {
sort_order: usize,
}

pub fn path_to_string(path: &syn::Path) -> String {
let x = format!("{}", quote! {#path});
x.replace(' ', "")
}

#[derive(Default)]
pub(crate) struct ItemNursery(Vec<NurseryItem>);

pub(crate) struct ValidatedItemNursery(ItemNursery);
pub(crate) struct ValidItemNursery(ItemNursery);

impl ItemNursery {
pub fn add_item(
Expand All @@ -58,7 +63,7 @@ impl ItemNursery {
Ok(())
}

pub fn validate(self) -> Result<ValidatedItemNursery> {
pub fn validate(self) -> Result<ValidItemNursery> {
let mut by_name: HashSet<(String, Vec<Attribute>)> = HashSet::new();
for item in &self.0 {
for py_name in &item.py_names {
Expand All @@ -71,11 +76,11 @@ impl ItemNursery {
}
}
}
Ok(ValidatedItemNursery(self))
Ok(ValidItemNursery(self))
}
}

impl ToTokens for ValidatedItemNursery {
impl ToTokens for ValidItemNursery {
fn to_tokens(&self, tokens: &mut TokenStream) {
let mut sorted = self.0 .0.clone();
sorted.sort_by(|a, b| a.sort_order.cmp(&b.sort_order));
Expand Down Expand Up @@ -379,6 +384,7 @@ pub(crate) fn path_eq(path: &Path, s: &str) -> bool {
}

pub(crate) trait AttributeExt: SynAttributeExt {
fn path_string(&self) -> String;
fn promoted_nested(&self) -> Result<PunctuatedNestedMeta>;
fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)>;
fn try_remove_name(&mut self, name: &str) -> Result<Option<syn::NestedMeta>>;
Expand All @@ -388,9 +394,12 @@ pub(crate) trait AttributeExt: SynAttributeExt {
}

impl AttributeExt for Attribute {
fn path_string(&self) -> String {
path_to_string(&self.path)
}
fn promoted_nested(&self) -> Result<PunctuatedNestedMeta> {
let list = self.promoted_list().map_err(|mut e| {
let name = self.get_ident().unwrap().to_string();
let name = self.path_string();
e.combine(syn::Error::new_spanned(
self,
format!(
Expand All @@ -404,7 +413,15 @@ impl AttributeExt for Attribute {
Ok(list.nested)
}
fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)> {
Ok((self.get_ident().unwrap(), self.promoted_nested()?))
let ident = self.get_ident().unwrap_or_else(|| {
&self
.path
.segments
.last()
.expect("py:: paths always have segment")
.ident
});
Ok((ident, self.promoted_nested()?))
}

fn try_remove_name(&mut self, item_name: &str) -> Result<Option<syn::NestedMeta>> {
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ mod array {
}
}

#[pyimpl(
#[pyclass(
flags(BASETYPE),
with(Comparable, AsBuffer, AsMapping, Iterable, Constructor)
)]
Expand Down Expand Up @@ -1282,7 +1282,7 @@ mod array {
array: PyArrayRef,
}

#[pyimpl(with(IterNext))]
#[pyclass(with(IterNext))]
impl PyArrayIter {}

impl IterNextIterable for PyArrayIter {}
Expand Down
6 changes: 3 additions & 3 deletions stdlib/src/contextvars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod _contextvars {
#[derive(Debug, PyPayload)]
struct PyContext {} // not to confuse with vm::Context

#[pyimpl(with(Initializer))]
#[pyclass(with(Initializer))]
impl PyContext {
#[pymethod]
fn run(
Expand Down Expand Up @@ -99,7 +99,7 @@ mod _contextvars {
default: OptionalArg<PyObjectRef>,
}

#[pyimpl(with(Initializer))]
#[pyclass(with(Initializer))]
impl ContextVar {
#[pyproperty]
fn name(&self) -> String {
Expand Down Expand Up @@ -172,7 +172,7 @@ mod _contextvars {
old_value: PyObjectRef,
}

#[pyimpl(with(Initializer))]
#[pyclass(with(Initializer))]
impl ContextToken {
#[pyproperty]
fn var(&self, _vm: &VirtualMachine) -> PyObjectRef {
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ mod _csv {
}
}

#[pyimpl(with(IterNext))]
#[pyclass(with(IterNext))]
impl Reader {}
impl IterNextIterable for Reader {}
impl IterNext for Reader {
Expand Down Expand Up @@ -255,7 +255,7 @@ mod _csv {
}
}

#[pyimpl]
#[pyclass]
impl Writer {
#[pymethod]
fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/hashlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ mod hashlib {
}
}

#[pyimpl]
#[pyclass]
impl PyHasher {
fn new(name: &str, d: HashWrapper) -> Self {
PyHasher {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ mod _json {
}
}

#[pyimpl(with(Callable, Constructor))]
#[pyclass(with(Callable, Constructor))]
impl JsonScanner {
fn parse(
&self,
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/pyexpat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mod _pyexpat {
vm.invoke(&handler.read().clone(), args).ok();
}

#[pyimpl]
#[pyclass]
impl PyExpatLikeXmlParser {
fn new(vm: &VirtualMachine) -> PyResult<PyExpatLikeXmlParserRef> {
Ok(PyExpatLikeXmlParser {
Expand All @@ -75,7 +75,7 @@ mod _pyexpat {
.into_ref(vm))
}

#[extend_class]
#[py::extend_class]
fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {
let mut attributes = class.attributes.write();

Expand Down
Loading