Skip to content

Commit 849495b

Browse files
authored
Merge pull request #5209 from sjrd/really-canonicalize-nan-bit-patterns
Fix #5208: Introduce raw floating point bit manipulation.
2 parents b9b6a6d + 47498f2 commit 849495b

File tree

16 files changed

+219
-71
lines changed

16 files changed

+219
-71
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7486,11 +7486,11 @@ private object GenJSCode {
74867486
m("numberOfLeadingZeros", List(J), I) -> ArgUnaryOp(unop.Long_clz)
74877487
),
74887488
jswkn.BoxedFloatClass.withSuffix("$") -> Map(
7489-
m("floatToIntBits", List(F), I) -> ArgUnaryOp(unop.Float_toBits),
7489+
m("floatToRawIntBits", List(F), I) -> ArgUnaryOp(unop.Float_toBits),
74907490
m("intBitsToFloat", List(I), F) -> ArgUnaryOp(unop.Float_fromBits)
74917491
),
74927492
jswkn.BoxedDoubleClass.withSuffix("$") -> Map(
7493-
m("doubleToLongBits", List(D), J) -> ArgUnaryOp(unop.Double_toBits),
7493+
m("doubleToRawLongBits", List(D), J) -> ArgUnaryOp(unop.Double_toBits),
74947494
m("longBitsToDouble", List(J), D) -> ArgUnaryOp(unop.Double_fromBits)
74957495
),
74967496
jswkn.BoxedStringClass -> Map(

ir/shared/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,17 +510,15 @@ object Trees {
510510
final val Throw = 31
511511

512512
// Floating point bit manipulation, introduced in 1.20
513-
final val Float_toBits = 32
514-
// final val Float_toRawBits = 33 // Reserved
515-
final val Float_fromBits = 34
516-
final val Double_toBits = 35
517-
// final val Double_toRawBits = 36 // Reserved
518-
final val Double_fromBits = 37
513+
final val Float_toBits = 32 // (this is the raw version, without any guarantee for NaN bit patterns)
514+
final val Float_fromBits = 33
515+
final val Double_toBits = 34 // (this is the raw version, without any guarantee for NaN bit patterns)
516+
final val Double_fromBits = 35
519517

520518
// Other nodes introduced in 1.20
521-
final val Int_clz = 38
522-
final val Long_clz = 39
523-
final val UnsignedIntToLong = 40
519+
final val Int_clz = 36
520+
final val Long_clz = 37
521+
final val UnsignedIntToLong = 38
524522

525523
def isClassOp(op: Code): Boolean =
526524
op >= Class_name && op <= Class_superClass

javalib/src/main/scala/java/io/DataOutputStream.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ class DataOutputStream(out: OutputStream)
6262
}
6363

6464
final def writeFloat(v: Float): Unit =
65-
writeInt(java.lang.Float.floatToIntBits(v))
65+
writeInt(java.lang.Float.floatToIntBits(v)) // must canonicalize NaNs
6666

6767
final def writeDouble(v: Double): Unit =
68-
writeLong(java.lang.Double.doubleToLongBits(v))
68+
writeLong(java.lang.Double.doubleToLongBits(v)) // must canonicalize NaNs
6969

7070
final def writeBytes(s: String): Unit = {
7171
for (i <- 0 until s.length())

javalib/src/main/scala/java/lang/Double.scala

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ object Double {
7474
final val SIZE = 64
7575
final val BYTES = 8
7676

77+
private final val PosInfinityBits = 0x7ff0000000000000L
78+
private final val CanonicalNaNBits = 0x7ff8000000000000L
79+
7780
@inline def `new`(value: scala.Double): Double = valueOf(value)
7881

7982
@inline def `new`(s: String): Double = valueOf(s)
@@ -280,7 +283,7 @@ object Double {
280283
val mbits = 52 // mantissa size
281284
val bias = (1 << (ebits - 1)) - 1
282285

283-
val bits = doubleToLongBits(d)
286+
val bits = doubleToRawLongBits(d)
284287
val s = bits < 0
285288
val m = bits & ((1L << mbits) - 1L)
286289
val e = (bits >>> mbits).toInt & ((1 << ebits) - 1) // biased
@@ -388,10 +391,12 @@ object Double {
388391

389392
@inline
390393
private def hashCodeForWasm(value: scala.Double): Int = {
391-
val bits = doubleToLongBits(value)
394+
val bits = doubleToRawLongBits(value)
392395
val valueInt = value.toInt
393-
if (doubleToLongBits(valueInt.toDouble) == bits)
396+
if (doubleToRawLongBits(valueInt.toDouble) == bits)
394397
valueInt
398+
else if (isNaNBitPattern(bits))
399+
Long.hashCode(CanonicalNaNBits)
395400
else
396401
Long.hashCode(bits)
397402
}
@@ -401,16 +406,41 @@ object Double {
401406
val valueInt = (value.asInstanceOf[js.Dynamic] | 0.asInstanceOf[js.Dynamic]).asInstanceOf[Int]
402407
if (valueInt.toDouble == value && 1.0/value != scala.Double.NegativeInfinity)
403408
valueInt
409+
else if (value != value)
410+
Long.hashCode(CanonicalNaNBits)
404411
else
405-
Long.hashCode(doubleToLongBits(value))
412+
Long.hashCode(doubleToRawLongBits(value))
406413
}
407414

408415
@inline def longBitsToDouble(bits: scala.Long): scala.Double =
409416
throw new Error("stub") // body replaced by the compiler back-end
410417

411-
@inline def doubleToLongBits(value: scala.Double): scala.Long =
418+
@inline def doubleToRawLongBits(value: scala.Double): scala.Long =
412419
throw new Error("stub") // body replaced by the compiler back-end
413420

421+
@inline def doubleToLongBits(value: scala.Double): scala.Long = {
422+
if (LinkingInfo.isWebAssembly) {
423+
val rawBits = doubleToRawLongBits(value)
424+
if (isNaNBitPattern(rawBits))
425+
CanonicalNaNBits
426+
else
427+
rawBits
428+
} else {
429+
/* On JS, the Long comparison inside isNaNBitPattern is expensive.
430+
* We compare to NaN at the double level instead.
431+
*/
432+
if (value != value)
433+
CanonicalNaNBits
434+
else
435+
doubleToRawLongBits(value)
436+
}
437+
}
438+
439+
@inline private def isNaNBitPattern(bits: scala.Long): scala.Boolean = {
440+
// Both operands are non-negative; it does not matter whether the comparison is signed or not
441+
(bits & ~scala.Long.MinValue) > PosInfinityBits
442+
}
443+
414444
@inline def sum(a: scala.Double, b: scala.Double): scala.Double =
415445
a + b
416446

javalib/src/main/scala/java/lang/Float.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ object Float {
7272
final val SIZE = 32
7373
final val BYTES = 4
7474

75+
private final val PosInfinityBits = 0x7f800000
76+
private final val CanonicalNaNBits = 0x7fc00000
77+
7578
@inline def `new`(value: scala.Float): Float = valueOf(value)
7679

7780
@inline def `new`(value: scala.Double): Float = valueOf(value.toFloat)
@@ -263,7 +266,7 @@ object Float {
263266
val kbits = 11 // number of bits of the exponent
264267
val bias = (1 << (kbits - 1)) - 1 // the bias of the exponent
265268

266-
val midBits = Double.doubleToLongBits(mid)
269+
val midBits = Double.doubleToRawLongBits(mid)
267270
val biasedK = (midBits >> mbits).toInt
268271

269272
/* Because `mid` is a double value halfway between two floats, it cannot
@@ -300,7 +303,7 @@ object Float {
300303
zDown
301304
else if (cmp > 0)
302305
zUp
303-
else if ((floatToIntBits(zDown) & 1) == 0) // zDown is even
306+
else if ((floatToRawIntBits(zDown) & 1) == 0) // zDown is even
304307
zDown
305308
else
306309
zUp
@@ -430,9 +433,22 @@ object Float {
430433
@inline def intBitsToFloat(bits: scala.Int): scala.Float =
431434
throw new Error("stub") // body replaced by the compiler back-end
432435

433-
@inline def floatToIntBits(value: scala.Float): scala.Int =
436+
@inline def floatToRawIntBits(value: scala.Float): scala.Int =
434437
throw new Error("stub") // body replaced by the compiler back-end
435438

439+
@inline def floatToIntBits(value: scala.Float): scala.Int = {
440+
val rawBits = floatToRawIntBits(value)
441+
if (isNaNBitPattern(rawBits))
442+
CanonicalNaNBits
443+
else
444+
rawBits
445+
}
446+
447+
@inline private def isNaNBitPattern(bits: scala.Int): scala.Boolean = {
448+
// Both operands are non-negative; it does not matter whether the comparison is signed or not
449+
(bits & ~Int.MinValue) > PosInfinityBits
450+
}
451+
436452
@inline def sum(a: scala.Float, b: scala.Float): scala.Float =
437453
a + b
438454

javalib/src/main/scala/java/lang/Math.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ object Math {
178178
} else if (a == -0.0) { // also matches +0.0 but that's fine
179179
scala.Double.MinPositiveValue
180180
} else {
181-
val abits = Double.doubleToLongBits(a)
181+
val abits = Double.doubleToRawLongBits(a)
182182
val rbits = if (a > 0) abits + 1L else abits - 1L
183183
Double.longBitsToDouble(rbits)
184184
}
@@ -190,7 +190,7 @@ object Math {
190190
} else if (a == -0.0f) { // also matches +0.0f but that's fine
191191
scala.Float.MinPositiveValue
192192
} else {
193-
val abits = Float.floatToIntBits(a)
193+
val abits = Float.floatToRawIntBits(a)
194194
val rbits = if (a > 0) abits + 1 else abits - 1
195195
Float.intBitsToFloat(rbits)
196196
}
@@ -202,7 +202,7 @@ object Math {
202202
} else if (a == 0.0) { // also matches -0.0 but that's fine
203203
-scala.Double.MinPositiveValue
204204
} else {
205-
val abits = Double.doubleToLongBits(a)
205+
val abits = Double.doubleToRawLongBits(a)
206206
val rbits = if (a > 0) abits - 1L else abits + 1L
207207
Double.longBitsToDouble(rbits)
208208
}
@@ -214,7 +214,7 @@ object Math {
214214
} else if (a == 0.0f) { // also matches -0.0f but that's fine
215215
-scala.Float.MinPositiveValue
216216
} else {
217-
val abits = Float.floatToIntBits(a)
217+
val abits = Float.floatToRawIntBits(a)
218218
val rbits = if (a > 0) abits - 1 else abits + 1
219219
Float.intBitsToFloat(rbits)
220220
}

javalib/src/main/scala/java/math/BigDecimal.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ class BigDecimal() extends Number with Comparable[BigDecimal] {
504504
if (JDouble.isInfinite(dVal) || JDouble.isNaN(dVal))
505505
throw new NumberFormatException("Infinity or NaN: " + dVal)
506506

507-
val bits = java.lang.Double.doubleToLongBits(dVal)
507+
val bits = java.lang.Double.doubleToRawLongBits(dVal)
508508
// Extracting the exponent, note that the bias is 1023
509509
_scale = 1075 - ((bits >> 52) & 2047).toInt
510510
// Extracting the 52 bits of the mantissa.

javalib/src/main/scala/java/nio/ByteArrayBits.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,12 @@ private[nio] final class ByteArrayBits(
170170

171171
@inline
172172
private def unmakeFloat(f: Float): (Byte, Byte, Byte, Byte) =
173-
unmakeInt(java.lang.Float.floatToIntBits(f))
173+
unmakeInt(java.lang.Float.floatToRawIntBits(f)) // NaN bit patterns are unspecified here
174174

175175
@inline
176176
private def unmakeDouble(
177177
d: Double): (Byte, Byte, Byte, Byte, Byte, Byte, Byte, Byte) =
178-
unmakeLong(java.lang.Double.doubleToLongBits(d))
178+
unmakeLong(java.lang.Double.doubleToRawLongBits(d)) // NaN bit patterns are unspecified here
179179

180180
// Loading and storing bytes
181181

javalib/src/main/scala/java/util/Formatter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ final class Formatter private (private[this] var dest: Appendable,
638638
val mbitsMask = ((1L << mbits) - 1L)
639639
val bias = (1 << (ebits - 1)) - 1
640640

641-
val bits = JDouble.doubleToLongBits(arg)
641+
val bits = JDouble.doubleToRawLongBits(arg)
642642
val negative = bits < 0
643643
val explicitMBits = bits & mbitsMask
644644
val biasedExponent = (bits >>> mbits).toInt & ((1 << ebits) - 1)

linker-private-library/src/main/scala/org/scalajs/linker/runtime/RuntimeLong.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ object RuntimeLong {
835835
}
836836
}
837837

838-
/** Computes `doubleToLongBits(value)`.
838+
/** Computes `doubleToRawLongBits(value)`.
839839
*
840840
* `fpBitsDataView` must be a scratch `js.typedarray.DataView` whose
841841
* underlying buffer is at least 8 bytes long.

0 commit comments

Comments
 (0)