@@ -3,12 +3,13 @@ use crate::parse::Item;
3
3
use crate :: receiver:: { has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf } ;
4
4
use proc_macro2:: TokenStream ;
5
5
use quote:: { format_ident, quote, quote_spanned, ToTokens } ;
6
+ use std:: collections:: BTreeSet as Set ;
6
7
use syn:: punctuated:: Punctuated ;
7
- use syn:: visit_mut:: VisitMut ;
8
+ use syn:: visit_mut:: { self , VisitMut } ;
8
9
use syn:: {
9
10
parse_quote, Attribute , Block , FnArg , GenericParam , Generics , Ident , ImplItem , Lifetime , Pat ,
10
11
PatIdent , Receiver , ReturnType , Signature , Stmt , Token , TraitItem , Type , TypeParamBound ,
11
- WhereClause ,
12
+ TypePath , WhereClause ,
12
13
} ;
13
14
14
15
macro_rules! parse_quote_spanned {
@@ -34,6 +35,7 @@ enum Context<'a> {
34
35
} ,
35
36
Impl {
36
37
impl_generics : & ' a Generics ,
38
+ associated_type_impl_traits : & ' a Set < Ident > ,
37
39
} ,
38
40
}
39
41
@@ -71,7 +73,7 @@ pub fn expand(input: &mut Item, is_local: bool) {
71
73
method. attrs . push ( parse_quote ! ( #[ must_use] ) ) ;
72
74
if let Some ( block) = block {
73
75
has_self |= has_self_in_block ( block) ;
74
- transform_block ( sig, block) ;
76
+ transform_block ( context , sig, block) ;
75
77
method. attrs . push ( lint_suppress_with_body ( ) ) ;
76
78
} else {
77
79
method. attrs . push ( lint_suppress_without_body ( ) ) ;
@@ -90,16 +92,26 @@ pub fn expand(input: &mut Item, is_local: bool) {
90
92
let elided = lifetimes. elided ;
91
93
input. generics . params = parse_quote ! ( #( #elided, ) * #params) ;
92
94
95
+ let mut associated_type_impl_traits = Set :: new ( ) ;
96
+ for inner in & input. items {
97
+ if let ImplItem :: Type ( assoc) = inner {
98
+ if let Type :: ImplTrait ( _) = assoc. ty {
99
+ associated_type_impl_traits. insert ( assoc. ident . clone ( ) ) ;
100
+ }
101
+ }
102
+ }
103
+
93
104
let context = Context :: Impl {
94
105
impl_generics : & input. generics ,
106
+ associated_type_impl_traits : & associated_type_impl_traits,
95
107
} ;
96
108
for inner in & mut input. items {
97
109
if let ImplItem :: Method ( method) = inner {
98
110
let sig = & mut method. sig ;
99
111
if sig. asyncness . is_some ( ) {
100
112
let block = & mut method. block ;
101
113
let has_self = has_self_in_sig ( sig) || has_self_in_block ( block) ;
102
- transform_block ( sig, block) ;
114
+ transform_block ( context , sig, block) ;
103
115
transform_sig ( context, sig, has_self, false , is_local) ;
104
116
method. attrs . push ( lint_suppress_with_body ( ) ) ;
105
117
}
@@ -296,7 +308,7 @@ fn transform_sig(
296
308
//
297
309
// ___ret
298
310
// })
299
- fn transform_block ( sig : & mut Signature , block : & mut Block ) {
311
+ fn transform_block ( context : Context , sig : & mut Signature , block : & mut Block ) {
300
312
if let Some ( Stmt :: Item ( syn:: Item :: Verbatim ( item) ) ) = block. stmts . first ( ) {
301
313
if block. stmts . len ( ) == 1 && item. to_string ( ) == ";" {
302
314
return ;
@@ -345,18 +357,24 @@ fn transform_block(sig: &mut Signature, block: &mut Block) {
345
357
}
346
358
347
359
let stmts = & block. stmts ;
348
- let let_ret = match & sig. output {
360
+ let let_ret = match & mut sig. output {
349
361
ReturnType :: Default => quote_spanned ! { block. brace_token. span=>
350
362
let _: ( ) = { #( #decls) * #( #stmts) * } ;
351
363
} ,
352
- ReturnType :: Type ( _, ret) => quote_spanned ! { block. brace_token. span=>
353
- if let :: core:: option:: Option :: Some ( __ret) = :: core:: option:: Option :: None :: <#ret> {
354
- return __ret;
364
+ ReturnType :: Type ( _, ret) => {
365
+ if contains_associated_type_impl_trait ( context, ret) {
366
+ quote ! ( #( #decls) * #( #stmts) * )
367
+ } else {
368
+ quote_spanned ! { block. brace_token. span=>
369
+ if let :: core:: option:: Option :: Some ( __ret) = :: core:: option:: Option :: None :: <#ret> {
370
+ return __ret;
371
+ }
372
+ let __ret: #ret = { #( #decls) * #( #stmts) * } ;
373
+ #[ allow( unreachable_code) ]
374
+ __ret
375
+ }
355
376
}
356
- let __ret: #ret = { #( #decls) * #( #stmts) * } ;
357
- #[ allow( unreachable_code) ]
358
- __ret
359
- } ,
377
+ }
360
378
} ;
361
379
let box_pin = quote_spanned ! ( block. brace_token. span=>
362
380
Box :: pin( async move { #let_ret } )
@@ -380,6 +398,41 @@ fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
380
398
false
381
399
}
382
400
401
+ fn contains_associated_type_impl_trait ( context : Context , ret : & mut Type ) -> bool {
402
+ struct AssociatedTypeImplTraits < ' a > {
403
+ set : & ' a Set < Ident > ,
404
+ contains : bool ,
405
+ }
406
+
407
+ impl < ' a > VisitMut for AssociatedTypeImplTraits < ' a > {
408
+ fn visit_type_path_mut ( & mut self , ty : & mut TypePath ) {
409
+ if ty. qself . is_none ( )
410
+ && ty. path . segments . len ( ) == 2
411
+ && ty. path . segments [ 0 ] . ident == "Self"
412
+ && self . set . contains ( & ty. path . segments [ 1 ] . ident )
413
+ {
414
+ self . contains = true ;
415
+ }
416
+ visit_mut:: visit_type_path_mut ( self , ty) ;
417
+ }
418
+ }
419
+
420
+ match context {
421
+ Context :: Trait { .. } => false ,
422
+ Context :: Impl {
423
+ associated_type_impl_traits,
424
+ ..
425
+ } => {
426
+ let mut visit = AssociatedTypeImplTraits {
427
+ set : associated_type_impl_traits,
428
+ contains : false ,
429
+ } ;
430
+ visit. visit_type_mut ( ret) ;
431
+ visit. contains
432
+ }
433
+ }
434
+ }
435
+
383
436
fn where_clause_or_default ( clause : & mut Option < WhereClause > ) -> & mut WhereClause {
384
437
clause. get_or_insert_with ( || WhereClause {
385
438
where_token : Default :: default ( ) ,
0 commit comments