Skip to content

Commit 3e2bdcf

Browse files
committed
Wasm: Emit fewer null pointer checks.
1 parent 0b94586 commit 3e2bdcf

File tree

3 files changed

+93
-28
lines changed

3 files changed

+93
-28
lines changed

linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ class CoreWasmLib(coreSpec: CoreSpec) {
417417
case DoubleRef => Float64
418418
case _ => Int32
419419
}
420-
addHelperImport(genFunctionID.box(primRef), List(wasmType), List(anyref))
420+
addHelperImport(genFunctionID.box(primRef), List(wasmType), List(RefType.any))
421421
addHelperImport(genFunctionID.unbox(primRef), List(anyref), List(wasmType))
422422
addHelperImport(genFunctionID.typeTest(primRef), List(anyref), List(Int32))
423423
}

linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -409,24 +409,78 @@ private class FunctionEmitter private (
409409
}
410410
}
411411

412-
/** Emits a `ref_as_non_null`, or an NPE check if required. */
413-
private def genAsNonNullOrNPE(): Unit = {
414-
if (semantics.nullPointers == CheckedBehavior.Unchecked)
415-
fb += wa.RefAsNonNull
416-
else
417-
fb += wa.BrOnNull(getNPELabel())
412+
/** Emits a `ref_as_non_null` or an NPE check if required for the given `Tree`.
413+
*
414+
* This method does not emit `tree`. It only uses it to determine whether
415+
* a check is required.
416+
*/
417+
private def genAsNonNullOrNPEFor(tree: Tree): Unit = {
418+
val nullabilityLevel = nullabilityLevelOf(tree)
419+
if (nullabilityLevel >= 1) {
420+
if (semantics.nullPointers != CheckedBehavior.Unchecked && nullabilityLevel >= 2)
421+
fb += wa.BrOnNull(getNPELabel())
422+
else
423+
fb += wa.RefAsNonNull
424+
}
418425
}
419426

420-
/** Emits an NPE check if required, otherwise nothing.
427+
/** Emits an NPE check if required for the given `Tree`, otherwise nothing.
428+
*
429+
* This method does not emit `tree`. It only uses it to determine whether
430+
* a check is required.
421431
*
422432
* Unlike `genAsNonNullOrNPE`, after this codegen the value on the stack is
423433
* still statically typed as nullable at the Wasm level.
424434
*/
425-
private def genCheckNonNull(): Unit = {
426-
if (semantics.nullPointers != CheckedBehavior.Unchecked)
435+
private def genCheckNonNullFor(tree: Tree): Unit = {
436+
if (semantics.nullPointers != CheckedBehavior.Unchecked && nullabilityLevelOf(tree) >= 2)
427437
fb += wa.BrOnNull(getNPELabel())
428438
}
429439

440+
/** Analyzes the nullability level of `tree`.
441+
*
442+
* - `0` if `tree` is statically known to generate a non-nullable Wasm value.
443+
* - `1` if `tree` is statically known to generate a non-nullable value,
444+
* but maybe typed as nullable at the Wasm level.
445+
* - `2` if `tree` can be `null`.
446+
*
447+
* See also: `isNotNull` in `emitter.FunctionEmitter`.
448+
*/
449+
private def nullabilityLevelOf(tree: Tree): Int = {
450+
// !!! Similar code in emitter.FunctionEmitter.isNotNull
451+
// !!! Similar code in OptimizerCore.isNotNull
452+
453+
def isNullableType(tpe: Type): Boolean = tpe match {
454+
case NullType => true
455+
case _: PrimType => false
456+
case _ => true
457+
}
458+
459+
def shapeNullabilityLevel(tree: Tree): Int = tree match {
460+
case Transient(Transients.CheckNotNull(_)) =>
461+
0
462+
case Transient(Transients.AssumeNotNull(expr)) =>
463+
Math.min(1, shapeNullabilityLevel(expr))
464+
case Transient(Transients.Cast(expr, _)) =>
465+
shapeNullabilityLevel(expr)
466+
case _: This =>
467+
if (tree.tpe != AnyType) 0
468+
else 2
469+
case _:New | _:NewArray | _:ArrayValue | _:Clone | _:ClassOf =>
470+
0
471+
case _: LoadModule =>
472+
if (semantics.moduleInit == CheckedBehavior.Compliant) 2
473+
else 0
474+
case _ =>
475+
2
476+
}
477+
478+
if (!isNullableType(tree.tpe))
479+
0
480+
else
481+
shapeNullabilityLevel(tree)
482+
}
483+
430484
/** Emits an unconditional NPE. */
431485
private def genNPE(): Unit = {
432486
if (semantics.nullPointers == CheckedBehavior.Unchecked)
@@ -652,7 +706,7 @@ private class FunctionEmitter private (
652706
markPosition(tree)
653707
genNPE()
654708
} else {
655-
genAsNonNullOrNPE()
709+
genAsNonNullOrNPEFor(qualifier)
656710
genTree(rhs, lhs.tpe)
657711
markPosition(tree)
658712
fb += wa.StructSet(
@@ -686,7 +740,7 @@ private class FunctionEmitter private (
686740
case _ => false
687741
}
688742

689-
genCheckNonNull()
743+
genCheckNonNullFor(array)
690744

691745
if (semantics.arrayIndexOutOfBounds == CheckedBehavior.Unchecked &&
692746
(semantics.arrayStores == CheckedBehavior.Unchecked || isPrimArray)) {
@@ -839,7 +893,7 @@ private class FunctionEmitter private (
839893

840894
// Load receiver and arguments
841895
genTree(receiver, AnyType)
842-
genAsNonNullOrNPE()
896+
genAsNonNullOrNPEFor(receiver)
843897
fb += wa.LocalTee(receiverLocalForDispatch)
844898
genArgs(args, methodName)
845899

@@ -901,7 +955,7 @@ private class FunctionEmitter private (
901955
*/
902956
def genReceiverNotNull(): Unit = {
903957
genTreeAuto(receiver)
904-
genAsNonNullOrNPE()
958+
genAsNonNullOrNPEFor(receiver)
905959
}
906960

907961
/* Generates a resolved call to a method of a hijacked class.
@@ -1180,14 +1234,14 @@ private class FunctionEmitter private (
11801234
BoxedClassToPrimType.get(targetClassName) match {
11811235
case None =>
11821236
genTree(receiver, ClassType(targetClassName))
1183-
genAsNonNullOrNPE()
1237+
genAsNonNullOrNPEFor(receiver)
11841238

11851239
case Some(primReceiverType) =>
11861240
if (receiver.tpe == primReceiverType) {
11871241
genTreeAuto(receiver)
11881242
} else {
11891243
genTree(receiver, AnyType)
1190-
genAsNonNullOrNPE()
1244+
genAsNonNullOrNPEFor(receiver)
11911245
genUnbox(primReceiverType)
11921246
}
11931247
}
@@ -1300,7 +1354,7 @@ private class FunctionEmitter private (
13001354
*/
13011355
genNPE()
13021356
} else {
1303-
genCheckNonNull()
1357+
genCheckNonNullFor(qualifier)
13041358
fb += wa.StructGet(
13051359
genTypeID.forClass(className),
13061360
genFieldID.forClassInstanceField(fieldName)
@@ -2056,9 +2110,15 @@ private class FunctionEmitter private (
20562110
case _ =>
20572111
targetWasmType match {
20582112
case watpe.RefType(true, watpe.HeapType.Any) =>
2059-
() // nothing to do
2113+
if (nullabilityLevelOf(expr) >= 1)
2114+
() // nothing to do
2115+
else
2116+
fb += wa.RefAsNonNull
20602117
case targetWasmType: watpe.RefType =>
2061-
fb += wa.RefCast(targetWasmType)
2118+
if (nullabilityLevelOf(expr) >= 2)
2119+
fb += wa.RefCast(targetWasmType)
2120+
else
2121+
fb += wa.RefCast(targetWasmType.toNonNullable)
20622122
case _ =>
20632123
throw new AssertionError(s"Unexpected type in AsInstanceOf: $targetTpe")
20642124
}
@@ -2082,7 +2142,11 @@ private class FunctionEmitter private (
20822142
fb += wa.GlobalGet(genGlobalID.undef)
20832143

20842144
case StringType =>
2085-
genAsNonNullOrNPE()
2145+
val sig = watpe.FunctionType(List(watpe.RefType.anyref), List(watpe.RefType.any))
2146+
fb.block(sig) { nonNullLabel =>
2147+
fb += wa.BrOnNonNull(nonNullLabel)
2148+
fb += wa.GlobalGet(genGlobalID.emptyString)
2149+
}
20862150

20872151
case targetTpe: PrimTypeWithRef =>
20882152
targetTpe match {
@@ -2139,13 +2203,13 @@ private class FunctionEmitter private (
21392203

21402204
genTreeAuto(expr)
21412205
markPosition(tree)
2142-
genCheckNonNull()
2206+
genCheckNonNullFor(expr)
21432207
fb += wa.StructGet(genTypeID.ObjectStruct, genFieldID.objStruct.vtable)
21442208
fb += wa.Call(genFunctionID.getClassOf)
21452209
} else {
21462210
genTree(expr, AnyType)
21472211
markPosition(tree)
2148-
genAsNonNullOrNPE()
2212+
genAsNonNullOrNPEFor(expr)
21492213
fb += wa.Call(genFunctionID.anyGetClass)
21502214
}
21512215

@@ -2497,7 +2561,7 @@ private class FunctionEmitter private (
24972561

24982562
markPosition(tree)
24992563

2500-
genAsNonNullOrNPE()
2564+
genAsNonNullOrNPEFor(expr)
25012565

25022566
// if !expr.isInstanceOf[js.JavaScriptException], then br $done
25032567
fb += wa.BrOnCastFail(
@@ -2736,7 +2800,7 @@ private class FunctionEmitter private (
27362800
array.tpe match {
27372801
case ArrayType(arrayTypeRef) =>
27382802
// Get the underlying array
2739-
genCheckNonNull()
2803+
genCheckNonNullFor(array)
27402804
fb += wa.StructGet(
27412805
genTypeID.forArrayClass(arrayTypeRef),
27422806
genFieldID.objStruct.arrayUnderlying
@@ -2840,7 +2904,7 @@ private class FunctionEmitter private (
28402904

28412905
array.tpe match {
28422906
case ArrayType(arrayTypeRef) =>
2843-
genCheckNonNull()
2907+
genCheckNonNullFor(array)
28442908

28452909
if (semantics.arrayIndexOutOfBounds == CheckedBehavior.Unchecked) {
28462910
// Get the underlying array
@@ -2998,7 +3062,7 @@ private class FunctionEmitter private (
29983062

29993063
markPosition(tree)
30003064

3001-
genAsNonNullOrNPE()
3065+
genAsNonNullOrNPEFor(expr)
30023066
fb += wa.LocalTee(exprLocal)
30033067

30043068
fb += wa.LocalGet(exprLocal)
@@ -3198,7 +3262,7 @@ private class FunctionEmitter private (
31983262
case tpe: PrimType =>
31993263
tpe
32003264
case tpe =>
3201-
genCheckNonNull()
3265+
genAsNonNullOrNPEFor(expr)
32023266
tpe
32033267
}
32043268

@@ -3215,7 +3279,7 @@ private class FunctionEmitter private (
32153279
case Transients.ObjectClassName(obj) =>
32163280
genTree(obj, AnyType)
32173281
markPosition(tree)
3218-
genAsNonNullOrNPE()
3282+
genAsNonNullOrNPEFor(obj)
32193283
fb += wa.Call(genFunctionID.anyGetClassName)
32203284
StringType
32213285

linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5394,6 +5394,7 @@ private[optimizer] abstract class OptimizerCore(
53945394

53955395
private def isNotNull(tree: Tree): Boolean = {
53965396
// !!! Duplicate code with FunctionEmitter.isNotNull
5397+
// !!! Similar code in wasmemitter.FunctionEmitter.nullabilityLevelOf
53975398

53985399
def isShapeNotNull(tree: Tree): Boolean = tree match {
53995400
case Transient(CheckNotNull(_) | AssumeNotNull(_)) =>

0 commit comments

Comments
 (0)