Skip to content

Commit cc9c90b

Browse files
committed
Support for impl Trait in associated type (type_alias_impl_trait)
1 parent 9dc0f16 commit cc9c90b

File tree

1 file changed

+66
-13
lines changed

1 file changed

+66
-13
lines changed

src/expand.rs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ use crate::parse::Item;
33
use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
44
use proc_macro2::TokenStream;
55
use quote::{format_ident, quote, quote_spanned, ToTokens};
6+
use std::collections::BTreeSet as Set;
67
use syn::punctuated::Punctuated;
7-
use syn::visit_mut::VisitMut;
8+
use syn::visit_mut::{self, VisitMut};
89
use syn::{
910
parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
1011
PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
11-
WhereClause,
12+
TypePath, WhereClause,
1213
};
1314

1415
macro_rules! parse_quote_spanned {
@@ -34,6 +35,7 @@ enum Context<'a> {
3435
},
3536
Impl {
3637
impl_generics: &'a Generics,
38+
associated_type_impl_traits: &'a Set<Ident>,
3739
},
3840
}
3941

@@ -71,7 +73,7 @@ pub fn expand(input: &mut Item, is_local: bool) {
7173
method.attrs.push(parse_quote!(#[must_use]));
7274
if let Some(block) = block {
7375
has_self |= has_self_in_block(block);
74-
transform_block(sig, block);
76+
transform_block(context, sig, block);
7577
method.attrs.push(lint_suppress_with_body());
7678
} else {
7779
method.attrs.push(lint_suppress_without_body());
@@ -90,16 +92,26 @@ pub fn expand(input: &mut Item, is_local: bool) {
9092
let elided = lifetimes.elided;
9193
input.generics.params = parse_quote!(#(#elided,)* #params);
9294

95+
let mut associated_type_impl_traits = Set::new();
96+
for inner in &input.items {
97+
if let ImplItem::Type(assoc) = inner {
98+
if let Type::ImplTrait(_) = assoc.ty {
99+
associated_type_impl_traits.insert(assoc.ident.clone());
100+
}
101+
}
102+
}
103+
93104
let context = Context::Impl {
94105
impl_generics: &input.generics,
106+
associated_type_impl_traits: &associated_type_impl_traits,
95107
};
96108
for inner in &mut input.items {
97109
if let ImplItem::Method(method) = inner {
98110
let sig = &mut method.sig;
99111
if sig.asyncness.is_some() {
100112
let block = &mut method.block;
101113
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
102-
transform_block(sig, block);
114+
transform_block(context, sig, block);
103115
transform_sig(context, sig, has_self, false, is_local);
104116
method.attrs.push(lint_suppress_with_body());
105117
}
@@ -296,7 +308,7 @@ fn transform_sig(
296308
//
297309
// ___ret
298310
// })
299-
fn transform_block(sig: &mut Signature, block: &mut Block) {
311+
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
300312
if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
301313
if block.stmts.len() == 1 && item.to_string() == ";" {
302314
return;
@@ -345,18 +357,24 @@ fn transform_block(sig: &mut Signature, block: &mut Block) {
345357
}
346358

347359
let stmts = &block.stmts;
348-
let let_ret = match &sig.output {
360+
let let_ret = match &mut sig.output {
349361
ReturnType::Default => quote_spanned! {block.brace_token.span=>
350362
let _: () = { #(#decls)* #(#stmts)* };
351363
},
352-
ReturnType::Type(_, ret) => quote_spanned! {block.brace_token.span=>
353-
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
354-
return __ret;
364+
ReturnType::Type(_, ret) => {
365+
if contains_associated_type_impl_trait(context, ret) {
366+
quote!(#(#decls)* #(#stmts)*)
367+
} else {
368+
quote_spanned! {block.brace_token.span=>
369+
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
370+
return __ret;
371+
}
372+
let __ret: #ret = { #(#decls)* #(#stmts)* };
373+
#[allow(unreachable_code)]
374+
__ret
375+
}
355376
}
356-
let __ret: #ret = { #(#decls)* #(#stmts)* };
357-
#[allow(unreachable_code)]
358-
__ret
359-
},
377+
}
360378
};
361379
let box_pin = quote_spanned!(block.brace_token.span=>
362380
Box::pin(async move { #let_ret })
@@ -380,6 +398,41 @@ fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
380398
false
381399
}
382400

401+
fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
402+
struct AssociatedTypeImplTraits<'a> {
403+
set: &'a Set<Ident>,
404+
contains: bool,
405+
}
406+
407+
impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
408+
fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
409+
if ty.qself.is_none()
410+
&& ty.path.segments.len() == 2
411+
&& ty.path.segments[0].ident == "Self"
412+
&& self.set.contains(&ty.path.segments[1].ident)
413+
{
414+
self.contains = true;
415+
}
416+
visit_mut::visit_type_path_mut(self, ty);
417+
}
418+
}
419+
420+
match context {
421+
Context::Trait { .. } => false,
422+
Context::Impl {
423+
associated_type_impl_traits,
424+
..
425+
} => {
426+
let mut visit = AssociatedTypeImplTraits {
427+
set: associated_type_impl_traits,
428+
contains: false,
429+
};
430+
visit.visit_type_mut(ret);
431+
visit.contains
432+
}
433+
}
434+
}
435+
383436
fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
384437
clause.get_or_insert_with(|| WhereClause {
385438
where_token: Default::default(),

0 commit comments

Comments
 (0)