Skip to content

Rust: Type inference for operator overloading #19593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ module Impl {

override string getOperatorName() { result = Generated::BinaryExpr.super.getOperatorName() }

override Expr getAnOperand() { result = [this.getLhs(), this.getRhs()] }
override Expr getOperand(int n) {
n = 0 and result = this.getLhs()
or
n = 1 and result = this.getRhs()
}
}
}
94 changes: 90 additions & 4 deletions rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,78 @@
private import rust
private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl

/**
* Holds if the operator `op` is overloaded to a trait with the canonical path
* `path` and the method name `method`.
*/
private predicate isOverloaded(string op, string path, string method) {
// Negation
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
or
// Not
op = "!" and path = "core::ops::bit::Not" and method = "not"
or
// Dereference
op = "*" and path = "core::ops::Deref" and method = "deref"
or
// Comparison operators
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
or
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
or
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
or
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
or
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
or
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
or
// Arithmetic operators
op = "+" and path = "core::ops::arith::Add" and method = "add"
or
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
or
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
or
op = "/" and path = "core::ops::arith::Div" and method = "div"
or
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
or
// Arithmetic assignment expressions
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
or
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
or
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
or
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
or
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
or
// Bitwise operators
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
or
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
or
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
or
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
or
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
or
// Bitwise assignment operators
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
or
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
or
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
or
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
or
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
}

/**
* INTERNAL: This module contains the customizable definition of `Operation` and should not
* be referenced directly.
Expand All @@ -16,14 +88,28 @@ module Impl {
* An operation, for example `&&`, `+=`, `!` or `*`.
*/
abstract class Operation extends ExprImpl::Expr {
/** Gets the operator name of this operation, if it exists. */
abstract string getOperatorName();

/** Gets the `n`th operand of this operation, if any. */
abstract Expr getOperand(int n);

/**
* Gets the operator name of this operation, if it exists.
* Gets the number of operands of this operation.
*
* This is either 1 for prefix operations, or 2 for binary operations.
*/
abstract string getOperatorName();
final int getNumberOfOperands() { result = strictcount(this.getAnOperand()) }

/** Gets an operand of this operation. */
Expr getAnOperand() { result = this.getOperand(_) }

/**
* Gets an operand of this operation.
* Holds if this operation is overloaded to the method `methodName` of the
* trait `trait`.
*/
abstract Expr getAnOperand();
predicate isOverloaded(Trait trait, string methodName) {
isOverloaded(this.getOperatorName(), trait.getCanonicalPath(), methodName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great to see that the QL computed canonical paths work here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed 🕺

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ module Impl {

override string getOperatorName() { result = Generated::PrefixExpr.super.getOperatorName() }

override Expr getAnOperand() { result = this.getExpr() }
override Expr getOperand(int n) { n = 0 and result = this.getExpr() }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ module Impl {

override string getOperatorName() { result = "&" }

override Expr getAnOperand() { result = this.getExpr() }
override Expr getOperand(int n) { n = 0 and result = this.getExpr() }

private string getSpecPart(int index) {
index = 0 and this.isRaw() and result = "raw"
Expand Down
65 changes: 60 additions & 5 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -643,20 +643,30 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {

private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl

class Access extends CallExprBase {
abstract class Access extends Expr {
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);

abstract AstNode getNodeAt(AccessPosition apos);

abstract Type getInferredType(AccessPosition apos, TypePath path);

abstract Declaration getTarget();
}

private class CallExprBaseAccess extends Access instanceof CallExprBase {
private TypeMention getMethodTypeArg(int i) {
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
}

Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
or
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition())
)
}

AstNode getNodeAt(AccessPosition apos) {
override AstNode getNodeAt(AccessPosition apos) {
exists(int p, boolean isMethodCall |
argPos(this, result, p, isMethodCall) and
apos = TPositionalAccessPosition(p, isMethodCall)
Expand All @@ -669,17 +679,42 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
apos = TReturnAccessPosition()
}

Type getInferredType(AccessPosition apos, TypePath path) {
override Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}

Declaration getTarget() {
override Declaration getTarget() {
result = CallExprImpl::getResolvedFunction(this)
or
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}

private class OperationAccess extends Access instanceof Operation {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add charpred with super.isOverloaded(_, _)?

OperationAccess() { super.isOverloaded(_, _) }

override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
// The syntax for operators does not allow type arguments.
none()
}

override AstNode getNodeAt(AccessPosition apos) {
result = super.getOperand(0) and apos = TSelfAccessPosition()
or
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
or
result = this and apos = TReturnAccessPosition()
}

override Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}

override Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}

predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos.isSelf() and
dpos.isSelf()
Expand Down Expand Up @@ -1059,6 +1094,26 @@ private module MethodCall {
pragma[nomagic]
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
}

private class OperationMethodCall extends MethodCallImpl instanceof Operation {
TraitItemNode trait;
string methodName;

OperationMethodCall() { super.isOverloaded(trait, methodName) }

override string getMethodName() { result = methodName }

override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }

override Trait getTrait() { result = trait }

pragma[nomagic]
override Type getTypeAt(TypePath path) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it was already so before this PR, but perhaps call this predicate getReceiverTypeAt instead? I have mistakenly thought it meant the type of the method call itself a couple of times...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I definitely get that, but I can't just change it as we need it to match the signature when it is given to IsInstantiationOf (which is generic and isn't just used for receivers).

result = inferType(this.(BinaryExpr).getLhs(), path)
or
result = inferType(this.(PrefixExpr).getExpr(), path)
}
}
}

import MethodCall
Expand Down
Loading