Skip to content

Enable the IR checker post optimizer with RT longs #5077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) {
new Analyzer(config, initial = false, checkIRFor, failOnError = true, irLoader)
}

/* TODO: Remove this and replace with `checkIR` once the optimizer generates
* well-typed IR with runtime longs.
*/
private val shouldRunIRChecker = {
val optimizerUsesRuntimeLong =
!config.coreSpec.esFeatures.allowBigIntsForLongs &&
!config.coreSpec.targetIsWebAssembly
checkIR && !optimizerUsesRuntimeLong
}

def refine(classDefs: Seq[(ClassDef, Version)],
moduleInitializers: List[ModuleInitializer],
symbolRequirements: SymbolRequirement, logger: Logger)(
Expand Down Expand Up @@ -81,7 +71,7 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) {
linkedTopLevelExports.flatten.toList, moduleInitializers, globalInfo)
}

if (shouldRunIRChecker) {
if (checkIR) {
logger.time("Refiner: Check IR") {
val errorCount = IRChecker.check(linkTimeProperties, result, logger,
CheckingPhase.Optimizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,29 +263,8 @@ private[optimizer] abstract class OptimizerCore(

private val isSubclassFun = isSubclass _

private def isSubtype(lhs: Type, rhs: Type): Boolean = {
assert(lhs != VoidType)
assert(rhs != VoidType)

Types.isSubtype(lhs, rhs)(isSubclassFun) || {
(lhs, rhs) match {
case (LongType, ClassType(LongImpl.RuntimeLongClass, _)) =>
true
case (ClassType(LongImpl.RuntimeLongClass, false), LongType) =>
true
case (ClassType(BoxedLongClass, lhsNullable),
ClassType(LongImpl.RuntimeLongClass, rhsNullable)) =>
rhsNullable || !lhsNullable

case (ClassType(LongImpl.RuntimeLongClass, lhsNullable),
ClassType(BoxedLongClass, rhsNullable)) =>
rhsNullable || !lhsNullable

case _ =>
false
}
}
}
private def isSubtype(lhs: Type, rhs: Type): Boolean =
Types.isSubtype(lhs, rhs)(isSubclassFun)

/** Transforms a statement.
*
Expand Down Expand Up @@ -577,8 +556,16 @@ private[optimizer] abstract class OptimizerCore(
case IsInstanceOf(expr, testType) =>
trampoline {
pretransformExpr(expr) { texpr =>
val texprType = texpr.tpe.base.toNonNullable

// Note: Disregards nullability because we can optimize null-check only.
val staticSubtype = {
isSubtype(texprType, testType) ||
(useRuntimeLong && isRTLong(testType) && isRTLong(texprType))
}

val result = {
if (isSubtype(texpr.tpe.base.toNonNullable, testType)) {
if (staticSubtype) {
if (texpr.tpe.isNullable)
BinaryOp(BinaryOp.!==, finishTransformExpr(texpr), Null())
else
Expand Down Expand Up @@ -762,10 +749,23 @@ private[optimizer] abstract class OptimizerCore(
def addCaptureParam(newName: LocalName): LocalDef = {
val newOriginalName = originalNameForFresh(paramName, originalName, newName)

val captureTpe = {
/* Do not refine the capture type for longs:
* The pretransform might be a stack allocated RuntimeLong.
* We cannot (trivially) capture it in stack allocated form.
* Therefore, we keep the primitive type and let finishTransformExpr
* allocate a RuntimeLong.
*
* TODO: Improve this and allocate two capture params for lo/hi?
*/
if (useRuntimeLong && paramDef.ptpe == LongType) RefinedType(LongType)
else tcaptureValue.tpe
}

val replacement = ReplaceWithVarRef(newName, newSimpleState(Unused))
val localDef = LocalDef(tcaptureValue.tpe, mutable, replacement)
val localDef = LocalDef(captureTpe, mutable, replacement)
val localIdent = LocalIdent(newName)(ident.pos)
val newParamDef = ParamDef(localIdent, newOriginalName, tcaptureValue.tpe.base, mutable)(paramDef.pos)
val newParamDef = ParamDef(localIdent, newOriginalName, captureTpe.base, mutable)(paramDef.pos)

/* Note that the binding will never create a fresh name for a
* ReplaceWithVarRef. So this will not put our name alignment at risk.
Expand Down Expand Up @@ -1297,12 +1297,22 @@ private[optimizer] abstract class OptimizerCore(
}

if (lhsStructure.className == LongImpl.RuntimeLongClass && trhs.tpe.base == LongType) {
/* The lhs is a stack-allocated RuntimeLong, but the rhs is a
* primitive Long. We expand the primitive Long into a new
* stack-allocated RuntimeLong so that we do not need to cancel.
*/
expandLongValue(trhs) { expandedRhs =>
buildInner(expandedRhs)
// The lhs is a stack-allocated RuntimeLong, the rhs is *typed* as primitive long.

trhs match {
case PreTransCast(trhs: PreTransRecordTree, _) =>
/* The rhs is also a stack allocated Long but was cast back to
* a primitive Long (due to method inlining). Remove the cast.
*/
buildInner(trhs)

case _ =>
/* The rhs is a primitive Long. We expand the primitive Long into
* a new stack-allocated RuntimeLong so that we do not need to cancel.
*/
expandLongValue(trhs) { expandedRhs =>
buildInner(expandedRhs)
}
}
} else {
buildInner(trhs)
Expand Down Expand Up @@ -1337,6 +1347,18 @@ private[optimizer] abstract class OptimizerCore(
private def resolveLocalDef(preTrans: PreTransform): PreTransGenTree = {
implicit val pos = preTrans.pos
preTrans match {
case PreTransCast(inner, refinedType) =>
resolveLocalDef(inner) match {
case tree: PreTransRecordTree =>
/* The call site will have to inspect the structure of the record
* tree. Therefore, dropping the cast here is fine.
*/
tree

case PreTransTree(innerTree, _) =>
PreTransTree(makeCast(innerTree, refinedType.base), refinedType)
}

case PreTransBlock(bindingsAndStats, result) =>
resolveLocalDef(result) match {
case PreTransRecordTree(tree, structure, cancelFun) =>
Expand Down Expand Up @@ -1382,6 +1404,9 @@ private[optimizer] abstract class OptimizerCore(
private def resolveRecordStructure(
preTrans: PreTransform): Option[(InlineableClassStructure, CancelFun)] = {
preTrans match {
case PreTransCast(preTrans, _) =>
resolveRecordStructure(preTrans)

case PreTransBlock(_, result) =>
resolveRecordStructure(result)

Expand Down Expand Up @@ -1441,6 +1466,8 @@ private[optimizer] abstract class OptimizerCore(
UnaryOp(op, finishTransformExpr(lhs))
case PreTransBinaryOp(op, lhs, rhs) =>
BinaryOp(op, finishTransformExpr(lhs), finishTransformExpr(rhs))
case PreTransCast(expr, refinedType) =>
makeCast(finishTransformExpr(expr), refinedType.base)
case PreTransLocalDef(localDef) =>
localDef.newReplacement

Expand Down Expand Up @@ -1531,6 +1558,8 @@ private[optimizer] abstract class OptimizerCore(
finishNoSideEffects
}

case PreTransCast(expr, _) =>
finishTransformStat(expr)
case PreTransLocalDef(_) =>
Skip()(stat.pos)
case PreTransRecordTree(tree, _, _) =>
Expand Down Expand Up @@ -3056,14 +3085,9 @@ private[optimizer] abstract class OptimizerCore(

case ClassGetName =>
optTReceiver.get match {
case PreTransMaybeBlock(bindingsAndStats,
PreTransTree(MaybeCast(UnaryOp(UnaryOp.GetClass, expr)), _)) =>
contTree(finishTransformBindings(
bindingsAndStats, Transient(ObjectClassName(expr))))

// Same thing, but the argument stayed as a PreTransUnaryOp
case PreTransMaybeBlock(bindingsAndStats,
PreTransUnaryOp(UnaryOp.GetClass, texpr)) =>
MaybeCast(PreTransUnaryOp(UnaryOp.GetClass, texpr))) =>
contTree(finishTransformBindings(
bindingsAndStats, Transient(ObjectClassName(finishTransformExpr(texpr)))))

Expand Down Expand Up @@ -5277,7 +5301,16 @@ private[optimizer] abstract class OptimizerCore(
def mayRequireUnboxing: Boolean =
arg.tpe.isNullable && tpe.isInstanceOf[PrimType]

if (semantics.asInstanceOfs == CheckedBehavior.Unchecked && !mayRequireUnboxing)
/* In methods on RuntimeLong, we often asInstanceOf Long to RuntimeLong and
* vice versa. We know that these are the same at runtime, so we lower to casts.
*/
val castForRTLong: Boolean = useRuntimeLong && {
val vtpe = arg.tpe.base
(!vtpe.isNullable || tpe.isNullable) &&
isRTLong(arg.tpe.base) && isRTLong(tpe)
}

if (semantics.asInstanceOfs == CheckedBehavior.Unchecked && !mayRequireUnboxing || castForRTLong)
foldCast(arg, tpe)
else if (isSubtype(arg.tpe.base, tpe))
arg
Expand All @@ -5288,38 +5321,24 @@ private[optimizer] abstract class OptimizerCore(
private def foldCast(arg: PreTransform, tpe: Type)(
implicit pos: Position): PreTransform = {

def default(arg: PreTransform, newTpe: RefinedType): PreTransform =
PreTransTree(makeCast(finishTransformExpr(arg), newTpe.base), newTpe)

def castLocalDef(arg: PreTransform, newTpe: RefinedType): PreTransform = arg match {
case PreTransMaybeBlock(bindingsAndStats, PreTransLocalDef(localDef)) =>
val refinedLocalDef = localDef.tryWithRefinedType(newTpe)
if (refinedLocalDef ne localDef)
PreTransBlock(bindingsAndStats, PreTransLocalDef(refinedLocalDef))
else
default(arg, newTpe)

case _ =>
default(arg, newTpe)
}

if (isSubtype(arg.tpe.base, tpe)) {
arg
} else {
lazy val castTpe = {
val tpe1 =
if (arg.tpe.isNullable) tpe
else tpe.toNonNullable
RefinedType(tpe1, isExact = false, arg.tpe.allocationSite)
}

val castTpe = RefinedType(tpe1, isExact = false, arg.tpe.allocationSite)
arg match {
case PreTransCast(arg, _) =>
// Replace existing cast.
foldCast(arg, tpe)

val isCastFreeAtRunTime = tpe != CharType
case arg if isSubtype(arg.tpe.base, tpe) =>
// Cast is redundant.
arg

if (isCastFreeAtRunTime) {
// Try to push the cast down to usages of LocalDefs, in order to preserve aliases
castLocalDef(arg, castTpe)
} else {
default(arg, castTpe)
}
case _ =>
PreTransCast(arg, castTpe)
}
}

Expand Down Expand Up @@ -5722,6 +5741,14 @@ private[optimizer] abstract class OptimizerCore(
*/
buildInner(localDef, cont)

case PreTransCast(PreTransLocalDef(
localDef @ LocalDef(_, mutable, replacement)), refinedType)
if !mutable && refinedType.base != CharType =>
// Casts to Char are not free, so we do not want to duplicate them.
val newLocalDef = LocalDef(refinedType, mutable,
ReplaceWithOtherLocalDef(localDef))
buildInner(newLocalDef, cont)

case PreTransTree(literal: Literal, _) =>
buildInner(LocalDef(value.tpe, false,
ReplaceWithConstant(literal)), cont)
Expand Down Expand Up @@ -5817,6 +5844,16 @@ private[optimizer] abstract class OptimizerCore(
else upperBound
}

/** Whether the given type is a RuntimeLong long at runtime.
*
* Assumes useRuntimeLong.
*/
private def isRTLong(tpe: Type) = tpe match {
case LongType => true
case ClassType(LongImpl.RuntimeLongClass | BoxedLongClass, _) => true
case _ => false
}

/** Trampolines a pretransform */
private def trampoline(tailrec: => TailRec[Tree]): Tree = {
// scalastyle:off return
Expand Down Expand Up @@ -6168,22 +6205,6 @@ private[optimizer] object OptimizerCore {
false
})
}

def tryWithRefinedType(refinedType: RefinedType): LocalDef = {
/* Only adjust if the replacement if ReplaceWithVarRef, because other
* types have nothing to gain (e.g., ReplaceWithConstant) or we want to
* keep them unwrapped because they are examined in optimizations
* (notably all the types with virtualized objects).
*/
replacement match {
case _:ReplaceWithVarRef =>
LocalDef(refinedType, mutable, ReplaceWithOtherLocalDef(this))
case replacement: ReplaceWithOtherLocalDef =>
LocalDef(refinedType, mutable, replacement)
case _ =>
this
}
}
}

private sealed abstract class LocalDefReplacement
Expand Down Expand Up @@ -6342,6 +6363,8 @@ private[optimizer] object OptimizerCore {
lhs.contains(localDef)
case PreTransBinaryOp(_, lhs, rhs) =>
lhs.contains(localDef) || rhs.contains(localDef)
case PreTransCast(expr, _) =>
expr.contains(localDef)
case PreTransLocalDef(thisLocalDef) =>
thisLocalDef.contains(localDef)
case _: PreTransGenTree =>
Expand Down Expand Up @@ -6482,6 +6505,9 @@ private[optimizer] object OptimizerCore {
val tpe: RefinedType = RefinedType(BinaryOp.resultTypeOf(op))
}

private final case class PreTransCast(expr: PreTransform, tpe: RefinedType)(
implicit val pos: Position) extends PreTransResult

/** A virtual reference to a `LocalDef`. */
private final case class PreTransLocalDef(localDef: LocalDef)(
implicit val pos: Position) extends PreTransResult {
Expand Down Expand Up @@ -6690,8 +6716,8 @@ private[optimizer] object OptimizerCore {
private def createNewLong(lo: Tree, hi: Tree)(
implicit pos: Position): Tree = {

New(LongImpl.RuntimeLongClass, MethodIdent(LongImpl.initFromParts),
List(lo, hi))
makeCast(New(LongImpl.RuntimeLongClass, MethodIdent(LongImpl.initFromParts),
List(lo, hi)), LongType)
}

/** Tests whether `x + y` is valid without falling out of range. */
Expand Down Expand Up @@ -7068,6 +7094,11 @@ private[optimizer] object OptimizerCore {
case Transient(Cast(inner, _)) => Some(inner)
case _ => Some(tree)
}

def unapply(tree: PreTransform): Some[PreTransform] = tree match {
case PreTransCast(inner, _) => Some(inner)
case _ => Some(tree)
}
}

private val TraitInitSimpleMethodName = SimpleMethodName("$init$")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,18 +490,10 @@ object IRCheckerTest {
moduleInitializers: List[ModuleInitializer],
logger: Logger, postOptimizer: Boolean)(
implicit ec: ExecutionContext): Future[Unit] = {
val baseConfig = StandardConfig()
val config = StandardConfig()
.withCheckIR(true)
.withOptimizer(false)

val config = {
/* Disable RuntimeLongs to workaround the Refiner disabling IRChecks in this case.
* TODO: Remove once we run IRChecks post optimizer all the time.
*/
if (postOptimizer) baseConfig.withESFeatures(_.withAllowBigIntsForLongs(true))
else baseConfig
}

val noSymbolRequirements = SymbolRequirement
.factory("IRCheckerTest")
.none()
Expand Down
Loading