Skip to content

Commit 80d91e2

Browse files
committed
Rust: Reimplement type inference for impl Traits and await expressions
1 parent 7e60cfc commit 80d91e2

File tree

8 files changed

+318
-158
lines changed

8 files changed

+318
-158
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
private import rust
66
private import codeql.rust.elements.internal.generated.ParentChild
77
private import codeql.rust.internal.CachedStages
8+
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
89

910
private newtype TNamespace =
1011
TTypeNamespace() or
@@ -165,6 +166,8 @@ abstract class ItemNode extends Locatable {
165166
or
166167
// type parameters have access to the associated items of its bounds
167168
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
169+
or
170+
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
168171
}
169172

170173
/**
@@ -618,6 +621,28 @@ class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
618621
}
619622
}
620623

624+
private class ImplTraitTypeReprItemNode extends ItemNode instanceof ImplTraitTypeRepr {
625+
pragma[nomagic]
626+
Path getABoundPath() {
627+
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
628+
}
629+
630+
pragma[nomagic]
631+
ItemNode resolveABound() { result = resolvePathFull(this.getABoundPath()) }
632+
633+
override string getName() { result = "(impl trait)" }
634+
635+
override Namespace getNamespace() { result.isType() }
636+
637+
override Visibility getVisibility() { none() }
638+
639+
override TypeParam getTypeParam(int i) { none() }
640+
641+
override predicate hasCanonicalPath(Crate c) { none() }
642+
643+
override string getCanonicalPath(Crate c) { none() }
644+
}
645+
621646
private class MacroCallItemNode extends AssocItemNode instanceof MacroCall {
622647
override string getName() { result = "(macro call)" }
623648

@@ -1061,14 +1086,21 @@ private predicate crateDefEdge(CrateItemNode c, string name, ItemNode i) {
10611086
not i instanceof Crate
10621087
}
10631088

1089+
private class BuiltinSourceFile extends SourceFileItemNode {
1090+
BuiltinSourceFile() { this.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder }
1091+
}
1092+
10641093
/**
10651094
* Holds if `m` depends on crate `dep` named `name`.
10661095
*/
1096+
pragma[nomagic]
10671097
private predicate crateDependencyEdge(ModuleLikeNode m, string name, CrateItemNode dep) {
1068-
exists(CrateItemNode c |
1069-
dep = c.(Crate).getDependency(name) and
1070-
m = c.getASourceFile()
1071-
)
1098+
exists(CrateItemNode c | dep = c.(Crate).getDependency(name) | m = c.getASourceFile())
1099+
or
1100+
// Give builtin files, such as `await.rs`, access to `std`
1101+
m instanceof BuiltinSourceFile and
1102+
dep.getName() = name and
1103+
name = "std"
10721104
}
10731105

10741106
private predicate useTreeDeclares(UseTree tree, string name) {
@@ -1413,9 +1445,14 @@ private predicate useImportEdge(Use use, string name, ItemNode item) {
14131445
* [1]: https://doc.rust-lang.org/core/prelude/index.html
14141446
* [2]: https://doc.rust-lang.org/std/prelude/index.html
14151447
*/
1448+
pragma[nomagic]
14161449
private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14171450
exists(Crate stdOrCore, ModuleLikeNode mod, ModuleItemNode prelude, ModuleItemNode rust |
1418-
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile() and
1451+
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile()
1452+
or
1453+
// Give builtin files, such as `await.rs`, access to the prelude
1454+
f instanceof BuiltinSourceFile
1455+
|
14191456
stdOrCore.getName() = ["std", "core"] and
14201457
mod = stdOrCore.getSourceFile() and
14211458
prelude = mod.getASuccessorRec("prelude") and
@@ -1425,12 +1462,10 @@ private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14251462
)
14261463
}
14271464

1428-
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
1429-
14301465
pragma[nomagic]
14311466
private predicate builtin(string name, ItemNode i) {
1432-
exists(SourceFileItemNode builtins |
1433-
builtins.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder and
1467+
exists(BuiltinSourceFile builtins |
1468+
builtins.getFile().getBaseName() = "types.rs" and
14341469
i = builtins.getASuccessorRec(name)
14351470
)
14361471
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@ newtype TType =
1515
TTrait(Trait t) or
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
18-
TImplTraitType(int bounds) {
19-
bounds = any(ImplTraitTypeRepr impl).getTypeBoundList().getNumberOfBounds()
20-
} or
18+
TImplTraitType(ImplTraitTypeRepr impl) or
2119
TTypeParamTypeParameter(TypeParam t) or
2220
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2321
TRefTypeParameter() or
24-
TSelfTypeParameter(Trait t) or
25-
TImplTraitTypeParameter(ImplTraitType t, int i) { i in [0 .. t.getNumberOfBounds() - 1] }
22+
TSelfTypeParameter(Trait t)
2623

2724
/**
2825
* A type without type arguments.
@@ -184,30 +181,50 @@ class RefType extends Type, TRefType {
184181
}
185182

186183
/**
187-
* An [`impl Trait`][1] type.
184+
* An [impl Trait][1] type.
188185
*
189-
* We represent `impl Trait` types as generic types with as many type parameters
190-
* as there are bounds.
186+
* Each syntactic `impl Trait` type gives rise to its own type, even if
187+
* two `impl Trait` types have the same bounds.
191188
*
192-
* [1] https://doc.rust-lang.org/book/ch10-02-traits.html#traits-as-parameters
189+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html
193190
*/
194191
class ImplTraitType extends Type, TImplTraitType {
195-
private int bounds;
192+
ImplTraitTypeRepr impl;
196193

197-
ImplTraitType() { this = TImplTraitType(bounds) }
194+
ImplTraitType() { this = TImplTraitType(impl) }
198195

199-
/** Gets the number of bounds of this `impl Trait` type. */
200-
int getNumberOfBounds() { result = bounds }
196+
/** Gets the underlying AST node. */
197+
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }
198+
199+
/** Gets the function that this `impl Trait` belongs to. */
200+
abstract Function getFunction();
201201

202202
override StructField getStructField(string name) { none() }
203203

204204
override TupleField getTupleField(int i) { none() }
205205

206-
override TypeParameter getTypeParameter(int i) { result = TImplTraitTypeParameter(this, i) }
206+
override TypeParameter getTypeParameter(int i) { none() }
207207

208-
override string toString() { result = "impl Trait ..." }
208+
override string toString() { result = impl.toString() }
209209

210-
override Location getLocation() { result instanceof EmptyLocation }
210+
override Location getLocation() { result = impl.getLocation() }
211+
}
212+
213+
/**
214+
* An [impl Trait in return position][1] type, for example:
215+
*
216+
* ```rust
217+
* fn foo() -> impl Trait
218+
* ```
219+
*
220+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.return
221+
*/
222+
class ImplTraitReturnType extends ImplTraitType {
223+
private Function function;
224+
225+
ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
226+
227+
override Function getFunction() { result = function }
211228
}
212229

213230
/** A type parameter. */
@@ -316,23 +333,34 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
316333
}
317334

318335
/**
319-
* An `impl Trait` type parameter.
336+
* An [impl Trait in argument position][1] type, for example:
337+
*
338+
* ```rust
339+
* fn foo(arg: impl Trait)
340+
* ```
341+
*
342+
* Such types are syntactic sugar for type parameters, that is
343+
*
344+
* ```rust
345+
* fn foo<T: Trait>(arg: T)
346+
* ```
347+
*
348+
* so we model them as type parameters.
349+
*
350+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.param
320351
*/
321-
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
322-
private ImplTraitType implTraitType;
323-
private int i;
352+
class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {
353+
private Function function;
324354

325-
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTraitType, i) }
355+
ImplTraitTypeTypeParameter() { impl = function.getParamList().getAParam().getTypeRepr() }
326356

327-
/** Gets the `impl Trait` type that this parameter belongs to. */
328-
ImplTraitType getImplTraitType() { result = implTraitType }
357+
override Function getFunction() { result = function }
329358

330-
/** Gets the index of this type parameter. */
331-
int getIndex() { result = i }
359+
override StructField getStructField(string name) { none() }
332360

333-
override string toString() { result = "impl Trait<" + i.toString() + ">" }
361+
override TupleField getTupleField(int i) { none() }
334362

335-
override Location getLocation() { result instanceof EmptyLocation }
363+
override TypeParameter getTypeParameter(int i) { none() }
336364
}
337365

338366
/**
@@ -370,3 +398,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
370398

371399
override TypeParamTypeParameter getATypeParameter() { none() }
372400
}
401+
402+
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
403+
override TypeParamTypeParameter getATypeParameter() { none() }
404+
}

0 commit comments

Comments
 (0)