Skip to content

Commit 47498f2

Browse files
committed
Fix #5208: Introduce raw floating point bit manipulation.
We repurpose the previous opcodes to mean the raw variants. They do not provide any guarantee for the NaN bit patterns (other than being one of the NaN patterns, obviously). We now implement the canonicalizing variants in user-space on top of the raw variants. On Wasm, we had to do that anyway, except the "user-space" was in the function emitter. On JavaScript, despite the spec "suggesting" that NaN's be canonicalized, real-world engines do not canonicalize all NaN's. Doing it in user-space is the only reliable way to guarantee our spec. --- We replace usages of the canonicalizing variants by raw variants in the low-level javalib methods where we can. Either they are in code paths where NaN's have been excluded, or where it does not actually matter what bit pattern we receive for NaN's. This allows not to regress on performance and code size for some low-level methods. Unfortunately, `Double.hashCode` does require one more branch, but that is inevitable to guarantee the correct semantics. `ByteBuffer` methods `putFloat` and `putDouble` do not specify the bit patterns of `NaN`s, and experimentally, I was able to observe non-canonical bit patterns on the JVM. So in that case, we also use the raw variants, which is consistent with using the `DataView` methods in `TypedArray`-backed byte buffers.
1 parent b9b6a6d commit 47498f2

File tree

16 files changed

+219
-71
lines changed

16 files changed

+219
-71
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7486,11 +7486,11 @@ private object GenJSCode {
74867486
m("numberOfLeadingZeros", List(J), I) -> ArgUnaryOp(unop.Long_clz)
74877487
),
74887488
jswkn.BoxedFloatClass.withSuffix("$") -> Map(
7489-
m("floatToIntBits", List(F), I) -> ArgUnaryOp(unop.Float_toBits),
7489+
m("floatToRawIntBits", List(F), I) -> ArgUnaryOp(unop.Float_toBits),
74907490
m("intBitsToFloat", List(I), F) -> ArgUnaryOp(unop.Float_fromBits)
74917491
),
74927492
jswkn.BoxedDoubleClass.withSuffix("$") -> Map(
7493-
m("doubleToLongBits", List(D), J) -> ArgUnaryOp(unop.Double_toBits),
7493+
m("doubleToRawLongBits", List(D), J) -> ArgUnaryOp(unop.Double_toBits),
74947494
m("longBitsToDouble", List(J), D) -> ArgUnaryOp(unop.Double_fromBits)
74957495
),
74967496
jswkn.BoxedStringClass -> Map(

ir/shared/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,17 +510,15 @@ object Trees {
510510
final val Throw = 31
511511

512512
// Floating point bit manipulation, introduced in 1.20
513-
final val Float_toBits = 32
514-
// final val Float_toRawBits = 33 // Reserved
515-
final val Float_fromBits = 34
516-
final val Double_toBits = 35
517-
// final val Double_toRawBits = 36 // Reserved
518-
final val Double_fromBits = 37
513+
final val Float_toBits = 32 // (this is the raw version, without any guarantee for NaN bit patterns)
514+
final val Float_fromBits = 33
515+
final val Double_toBits = 34 // (this is the raw version, without any guarantee for NaN bit patterns)
516+
final val Double_fromBits = 35
519517

520518
// Other nodes introduced in 1.20
521-
final val Int_clz = 38
522-
final val Long_clz = 39
523-
final val UnsignedIntToLong = 40
519+
final val Int_clz = 36
520+
final val Long_clz = 37
521+
final val UnsignedIntToLong = 38
524522

525523
def isClassOp(op: Code): Boolean =
526524
op >= Class_name && op <= Class_superClass

javalib/src/main/scala/java/io/DataOutputStream.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ class DataOutputStream(out: OutputStream)
6262
}
6363

6464
final def writeFloat(v: Float): Unit =
65-
writeInt(java.lang.Float.floatToIntBits(v))
65+
writeInt(java.lang.Float.floatToIntBits(v)) // must canonicalize NaNs
6666

6767
final def writeDouble(v: Double): Unit =
68-
writeLong(java.lang.Double.doubleToLongBits(v))
68+
writeLong(java.lang.Double.doubleToLongBits(v)) // must canonicalize NaNs
6969

7070
final def writeBytes(s: String): Unit = {
7171
for (i <- 0 until s.length())

javalib/src/main/scala/java/lang/Double.scala

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ object Double {
7474
final val SIZE = 64
7575
final val BYTES = 8
7676

77+
private final val PosInfinityBits = 0x7ff0000000000000L
78+
private final val CanonicalNaNBits = 0x7ff8000000000000L
79+
7780
@inline def `new`(value: scala.Double): Double = valueOf(value)
7881

7982
@inline def `new`(s: String): Double = valueOf(s)
@@ -280,7 +283,7 @@ object Double {
280283
val mbits = 52 // mantissa size
281284
val bias = (1 << (ebits - 1)) - 1
282285

283-
val bits = doubleToLongBits(d)
286+
val bits = doubleToRawLongBits(d)
284287
val s = bits < 0
285288
val m = bits & ((1L << mbits) - 1L)
286289
val e = (bits >>> mbits).toInt & ((1 << ebits) - 1) // biased
@@ -388,10 +391,12 @@ object Double {
388391

389392
@inline
390393
private def hashCodeForWasm(value: scala.Double): Int = {
391-
val bits = doubleToLongBits(value)
394+
val bits = doubleToRawLongBits(value)
392395
val valueInt = value.toInt
393-
if (doubleToLongBits(valueInt.toDouble) == bits)
396+
if (doubleToRawLongBits(valueInt.toDouble) == bits)
394397
valueInt
398+
else if (isNaNBitPattern(bits))
399+
Long.hashCode(CanonicalNaNBits)
395400
else
396401
Long.hashCode(bits)
397402
}
@@ -401,16 +406,41 @@ object Double {
401406
val valueInt = (value.asInstanceOf[js.Dynamic] | 0.asInstanceOf[js.Dynamic]).asInstanceOf[Int]
402407
if (valueInt.toDouble == value && 1.0/value != scala.Double.NegativeInfinity)
403408
valueInt
409+
else if (value != value)
410+
Long.hashCode(CanonicalNaNBits)
404411
else
405-
Long.hashCode(doubleToLongBits(value))
412+
Long.hashCode(doubleToRawLongBits(value))
406413
}
407414

408415
@inline def longBitsToDouble(bits: scala.Long): scala.Double =
409416
throw new Error("stub") // body replaced by the compiler back-end
410417

411-
@inline def doubleToLongBits(value: scala.Double): scala.Long =
418+
@inline def doubleToRawLongBits(value: scala.Double): scala.Long =
412419
throw new Error("stub") // body replaced by the compiler back-end
413420

421+
@inline def doubleToLongBits(value: scala.Double): scala.Long = {
422+
if (LinkingInfo.isWebAssembly) {
423+
val rawBits = doubleToRawLongBits(value)
424+
if (isNaNBitPattern(rawBits))
425+
CanonicalNaNBits
426+
else
427+
rawBits
428+
} else {
429+
/* On JS, the Long comparison inside isNaNBitPattern is expensive.
430+
* We compare to NaN at the double level instead.
431+
*/
432+
if (value != value)
433+
CanonicalNaNBits
434+
else
435+
doubleToRawLongBits(value)
436+
}
437+
}
438+
439+
@inline private def isNaNBitPattern(bits: scala.Long): scala.Boolean = {
440+
// Both operands are non-negative; it does not matter whether the comparison is signed or not
441+
(bits & ~scala.Long.MinValue) > PosInfinityBits
442+
}
443+
414444
@inline def sum(a: scala.Double, b: scala.Double): scala.Double =
415445
a + b
416446

javalib/src/main/scala/java/lang/Float.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ object Float {
7272
final val SIZE = 32
7373
final val BYTES = 4
7474

75+
private final val PosInfinityBits = 0x7f800000
76+
private final val CanonicalNaNBits = 0x7fc00000
77+
7578
@inline def `new`(value: scala.Float): Float = valueOf(value)
7679

7780
@inline def `new`(value: scala.Double): Float = valueOf(value.toFloat)
@@ -263,7 +266,7 @@ object Float {
263266
val kbits = 11 // number of bits of the exponent
264267
val bias = (1 << (kbits - 1)) - 1 // the bias of the exponent
265268

266-
val midBits = Double.doubleToLongBits(mid)
269+
val midBits = Double.doubleToRawLongBits(mid)
267270
val biasedK = (midBits >> mbits).toInt
268271

269272
/* Because `mid` is a double value halfway between two floats, it cannot
@@ -300,7 +303,7 @@ object Float {
300303
zDown
301304
else if (cmp > 0)
302305
zUp
303-
else if ((floatToIntBits(zDown) & 1) == 0) // zDown is even
306+
else if ((floatToRawIntBits(zDown) & 1) == 0) // zDown is even
304307
zDown
305308
else
306309
zUp
@@ -430,9 +433,22 @@ object Float {
430433
@inline def intBitsToFloat(bits: scala.Int): scala.Float =
431434
throw new Error("stub") // body replaced by the compiler back-end
432435

433-
@inline def floatToIntBits(value: scala.Float): scala.Int =
436+
@inline def floatToRawIntBits(value: scala.Float): scala.Int =
434437
throw new Error("stub") // body replaced by the compiler back-end
435438

439+
@inline def floatToIntBits(value: scala.Float): scala.Int = {
440+
val rawBits = floatToRawIntBits(value)
441+
if (isNaNBitPattern(rawBits))
442+
CanonicalNaNBits
443+
else
444+
rawBits
445+
}
446+
447+
@inline private def isNaNBitPattern(bits: scala.Int): scala.Boolean = {
448+
// Both operands are non-negative; it does not matter whether the comparison is signed or not
449+
(bits & ~Int.MinValue) > PosInfinityBits
450+
}
451+
436452
@inline def sum(a: scala.Float, b: scala.Float): scala.Float =
437453
a + b
438454

javalib/src/main/scala/java/lang/Math.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ object Math {
178178
} else if (a == -0.0) { // also matches +0.0 but that's fine
179179
scala.Double.MinPositiveValue
180180
} else {
181-
val abits = Double.doubleToLongBits(a)
181+
val abits = Double.doubleToRawLongBits(a)
182182
val rbits = if (a > 0) abits + 1L else abits - 1L
183183
Double.longBitsToDouble(rbits)
184184
}
@@ -190,7 +190,7 @@ object Math {
190190
} else if (a == -0.0f) { // also matches +0.0f but that's fine
191191
scala.Float.MinPositiveValue
192192
} else {
193-
val abits = Float.floatToIntBits(a)
193+
val abits = Float.floatToRawIntBits(a)
194194
val rbits = if (a > 0) abits + 1 else abits - 1
195195
Float.intBitsToFloat(rbits)
196196
}
@@ -202,7 +202,7 @@ object Math {
202202
} else if (a == 0.0) { // also matches -0.0 but that's fine
203203
-scala.Double.MinPositiveValue
204204
} else {
205-
val abits = Double.doubleToLongBits(a)
205+
val abits = Double.doubleToRawLongBits(a)
206206
val rbits = if (a > 0) abits - 1L else abits + 1L
207207
Double.longBitsToDouble(rbits)
208208
}
@@ -214,7 +214,7 @@ object Math {
214214
} else if (a == 0.0f) { // also matches -0.0f but that's fine
215215
-scala.Float.MinPositiveValue
216216
} else {
217-
val abits = Float.floatToIntBits(a)
217+
val abits = Float.floatToRawIntBits(a)
218218
val rbits = if (a > 0) abits - 1 else abits + 1
219219
Float.intBitsToFloat(rbits)
220220
}

javalib/src/main/scala/java/math/BigDecimal.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ class BigDecimal() extends Number with Comparable[BigDecimal] {
504504
if (JDouble.isInfinite(dVal) || JDouble.isNaN(dVal))
505505
throw new NumberFormatException("Infinity or NaN: " + dVal)
506506

507-
val bits = java.lang.Double.doubleToLongBits(dVal)
507+
val bits = java.lang.Double.doubleToRawLongBits(dVal)
508508
// Extracting the exponent, note that the bias is 1023
509509
_scale = 1075 - ((bits >> 52) & 2047).toInt
510510
// Extracting the 52 bits of the mantissa.

javalib/src/main/scala/java/nio/ByteArrayBits.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,12 @@ private[nio] final class ByteArrayBits(
170170

171171
@inline
172172
private def unmakeFloat(f: Float): (Byte, Byte, Byte, Byte) =
173-
unmakeInt(java.lang.Float.floatToIntBits(f))
173+
unmakeInt(java.lang.Float.floatToRawIntBits(f)) // NaN bit patterns are unspecified here
174174

175175
@inline
176176
private def unmakeDouble(
177177
d: Double): (Byte, Byte, Byte, Byte, Byte, Byte, Byte, Byte) =
178-
unmakeLong(java.lang.Double.doubleToLongBits(d))
178+
unmakeLong(java.lang.Double.doubleToRawLongBits(d)) // NaN bit patterns are unspecified here
179179

180180
// Loading and storing bytes
181181

javalib/src/main/scala/java/util/Formatter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ final class Formatter private (private[this] var dest: Appendable,
638638
val mbitsMask = ((1L << mbits) - 1L)
639639
val bias = (1 << (ebits - 1)) - 1
640640

641-
val bits = JDouble.doubleToLongBits(arg)
641+
val bits = JDouble.doubleToRawLongBits(arg)
642642
val negative = bits < 0
643643
val explicitMBits = bits & mbitsMask
644644
val biasedExponent = (bits >>> mbits).toInt & ((1 << ebits) - 1)

linker-private-library/src/main/scala/org/scalajs/linker/runtime/RuntimeLong.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ object RuntimeLong {
835835
}
836836
}
837837

838-
/** Computes `doubleToLongBits(value)`.
838+
/** Computes `doubleToRawLongBits(value)`.
839839
*
840840
* `fpBitsDataView` must be a scratch `js.typedarray.DataView` whose
841841
* underlying buffer is at least 8 bytes long.

0 commit comments

Comments
 (0)