-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Rust: Type inference for operator overloading #19593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6e9a4be
d92d454
5160bc2
6500ebf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,78 @@ | |
private import rust | ||
private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl | ||
|
||
/** | ||
* Holds if the operator `op` is overloaded to a trait with the canonical path | ||
* `path` and the method name `method`. | ||
*/ | ||
private predicate isOverloaded(string op, string path, string method) { | ||
// Negation | ||
op = "-" and path = "core::ops::arith::Neg" and method = "neg" | ||
or | ||
// Not | ||
op = "!" and path = "core::ops::bit::Not" and method = "not" | ||
or | ||
// Dereference | ||
op = "*" and path = "core::ops::Deref" and method = "deref" | ||
or | ||
// Comparison operators | ||
op = "==" and path = "core::cmp::PartialEq" and method = "eq" | ||
or | ||
op = "!=" and path = "core::cmp::PartialEq" and method = "ne" | ||
or | ||
op = "<" and path = "core::cmp::PartialOrd" and method = "lt" | ||
or | ||
op = "<=" and path = "core::cmp::PartialOrd" and method = "le" | ||
or | ||
op = ">" and path = "core::cmp::PartialOrd" and method = "gt" | ||
or | ||
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge" | ||
or | ||
// Arithmetic operators | ||
op = "+" and path = "core::ops::arith::Add" and method = "add" | ||
or | ||
op = "-" and path = "core::ops::arith::Sub" and method = "sub" | ||
or | ||
op = "*" and path = "core::ops::arith::Mul" and method = "mul" | ||
or | ||
op = "/" and path = "core::ops::arith::Div" and method = "div" | ||
or | ||
op = "%" and path = "core::ops::arith::Rem" and method = "rem" | ||
or | ||
// Arithmetic assignment expressions | ||
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign" | ||
or | ||
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign" | ||
or | ||
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign" | ||
or | ||
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign" | ||
or | ||
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign" | ||
or | ||
// Bitwise operators | ||
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand" | ||
or | ||
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor" | ||
or | ||
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor" | ||
or | ||
op = "<<" and path = "core::ops::bit::Shl" and method = "shl" | ||
or | ||
op = ">>" and path = "core::ops::bit::Shr" and method = "shr" | ||
or | ||
// Bitwise assignment operators | ||
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign" | ||
or | ||
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign" | ||
or | ||
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign" | ||
or | ||
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign" | ||
or | ||
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign" | ||
} | ||
|
||
/** | ||
* INTERNAL: This module contains the customizable definition of `Operation` and should not | ||
* be referenced directly. | ||
|
@@ -16,14 +88,28 @@ module Impl { | |
* An operation, for example `&&`, `+=`, `!` or `*`. | ||
*/ | ||
abstract class Operation extends ExprImpl::Expr { | ||
/** Gets the operator name of this operation, if it exists. */ | ||
abstract string getOperatorName(); | ||
|
||
/** Gets the `n`th operand of this operation, if any. */ | ||
abstract Expr getOperand(int n); | ||
|
||
/** | ||
* Gets the operator name of this operation, if it exists. | ||
* Gets the number of operands of this operation. | ||
* | ||
* This is either 1 for prefix operations, or 2 for binary operations. | ||
*/ | ||
abstract string getOperatorName(); | ||
final int getNumberOfOperands() { result = strictcount(this.getAnOperand()) } | ||
|
||
/** Gets an operand of this operation. */ | ||
Expr getAnOperand() { result = this.getOperand(_) } | ||
|
||
/** | ||
* Gets an operand of this operation. | ||
* Holds if this operation is overloaded to the method `methodName` of the | ||
* trait `trait`. | ||
*/ | ||
abstract Expr getAnOperand(); | ||
predicate isOverloaded(Trait trait, string methodName) { | ||
isOverloaded(this.getOperatorName(), trait.getCanonicalPath(), methodName) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great to see that the QL computed canonical paths work here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed 🕺 |
||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -643,20 +643,30 @@ private module CallExprBaseMatchingInput implements MatchingInputSig { | |
|
||
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl | ||
|
||
class Access extends CallExprBase { | ||
abstract class Access extends Expr { | ||
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path); | ||
|
||
abstract AstNode getNodeAt(AccessPosition apos); | ||
|
||
abstract Type getInferredType(AccessPosition apos, TypePath path); | ||
|
||
abstract Declaration getTarget(); | ||
} | ||
|
||
private class CallExprBaseAccess extends Access instanceof CallExprBase { | ||
private TypeMention getMethodTypeArg(int i) { | ||
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i) | ||
} | ||
|
||
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { | ||
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { | ||
exists(TypeMention arg | result = arg.resolveTypeAt(path) | | ||
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam()) | ||
or | ||
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition()) | ||
) | ||
} | ||
|
||
AstNode getNodeAt(AccessPosition apos) { | ||
override AstNode getNodeAt(AccessPosition apos) { | ||
exists(int p, boolean isMethodCall | | ||
argPos(this, result, p, isMethodCall) and | ||
apos = TPositionalAccessPosition(p, isMethodCall) | ||
|
@@ -669,17 +679,42 @@ private module CallExprBaseMatchingInput implements MatchingInputSig { | |
apos = TReturnAccessPosition() | ||
} | ||
|
||
Type getInferredType(AccessPosition apos, TypePath path) { | ||
override Type getInferredType(AccessPosition apos, TypePath path) { | ||
result = inferType(this.getNodeAt(apos), path) | ||
} | ||
|
||
Declaration getTarget() { | ||
override Declaration getTarget() { | ||
result = CallExprImpl::getResolvedFunction(this) | ||
or | ||
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa | ||
} | ||
} | ||
|
||
private class OperationAccess extends Access instanceof Operation { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add charpred with |
||
OperationAccess() { super.isOverloaded(_, _) } | ||
|
||
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { | ||
// The syntax for operators does not allow type arguments. | ||
none() | ||
} | ||
|
||
override AstNode getNodeAt(AccessPosition apos) { | ||
result = super.getOperand(0) and apos = TSelfAccessPosition() | ||
or | ||
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true) | ||
or | ||
result = this and apos = TReturnAccessPosition() | ||
} | ||
|
||
override Type getInferredType(AccessPosition apos, TypePath path) { | ||
result = inferType(this.getNodeAt(apos), path) | ||
} | ||
|
||
override Declaration getTarget() { | ||
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa | ||
} | ||
} | ||
|
||
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) { | ||
apos.isSelf() and | ||
dpos.isSelf() | ||
|
@@ -1059,6 +1094,26 @@ private module MethodCall { | |
pragma[nomagic] | ||
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) } | ||
} | ||
|
||
private class OperationMethodCall extends MethodCallImpl instanceof Operation { | ||
TraitItemNode trait; | ||
string methodName; | ||
|
||
OperationMethodCall() { super.isOverloaded(trait, methodName) } | ||
|
||
override string getMethodName() { result = methodName } | ||
|
||
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 } | ||
|
||
override Trait getTrait() { result = trait } | ||
|
||
pragma[nomagic] | ||
override Type getTypeAt(TypePath path) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know it was already so before this PR, but perhaps call this predicate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I definitely get that, but I can't just change it as we need it to match the signature when it is given to |
||
result = inferType(this.(BinaryExpr).getLhs(), path) | ||
or | ||
result = inferType(this.(PrefixExpr).getExpr(), path) | ||
} | ||
} | ||
} | ||
|
||
import MethodCall | ||
|
Uh oh!
There was an error while loading. Please reload this page.