Skip to content

Commit afdda6f

Browse files
committed
Wasm: Implement PriorityQueue without js.Array in Wasm backend
Replaced the `js.Array`-based `PriorityQueue` backend with a `scala.Array` implementation when targeting Wasm. The previous js.Array based PriorityQueue implementation was not very performant on Wasm, because of the overhead of JS-interop for each inner array operations. A `linkTimeIf` condition now selects the appropriate array implementation at build time. Benchmarks show a 2-4x speedup for `add`, `peek`, and `poll` operations on the Wasm backend. tanishiking/scalajs-benchmarks#3
1 parent 94acdf8 commit afdda6f

File tree

1 file changed

+175
-50
lines changed

1 file changed

+175
-50
lines changed

javalib/src/main/scala/java/util/PriorityQueue.scala

Lines changed: 175 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,49 @@
1212

1313
package java.util
1414

15+
import scala.language.higherKinds
16+
1517
import scala.annotation.tailrec
1618

17-
import scala.scalajs.js
19+
import java.lang.Utils.roundUpToPowerOfTwo
20+
21+
import scala.scalajs.LinkingInfo
1822

1923
class PriorityQueue[E] private (
20-
private val comp: Comparator[_ >: E], internal: Boolean)
24+
private val comp: Comparator[_ >: E], internal: Boolean, initialCapacity: Int)
2125
extends AbstractQueue[E] with Serializable {
2226

27+
import PriorityQueue._
28+
2329
def this() =
24-
this(NaturalComparator, internal = true)
30+
this(NaturalComparator, internal = true, initialCapacity = 16)
2531

2632
def this(initialCapacity: Int) = {
27-
this()
28-
if (initialCapacity < 1)
29-
throw new IllegalArgumentException()
33+
this(
34+
NaturalComparator,
35+
internal = true,
36+
{
37+
if (initialCapacity < 1)
38+
throw new IllegalArgumentException
39+
initialCapacity + 1 // index 0 is unused
40+
}
41+
)
3042
}
3143

3244
def this(comparator: Comparator[_ >: E]) = {
33-
this(NaturalComparator.select(comparator), internal = true)
45+
this(NaturalComparator.select(comparator), internal = true, initialCapacity = 16)
3446
}
3547

3648
def this(initialCapacity: Int, comparator: Comparator[_ >: E]) = {
37-
this(comparator)
38-
if (initialCapacity < 1)
39-
throw new IllegalArgumentException()
49+
this(
50+
NaturalComparator.select(comparator),
51+
internal = true,
52+
{
53+
if (initialCapacity < 1)
54+
throw new IllegalArgumentException()
55+
initialCapacity + 1 // index 0 is unused
56+
}
57+
)
4058
}
4159

4260
def this(c: Collection[_ <: E]) = {
@@ -47,47 +65,51 @@ class PriorityQueue[E] private (
4765
NaturalComparator.select(c.comparator().asInstanceOf[Comparator[_ >: E]])
4866
case _ =>
4967
NaturalComparator
50-
}, internal = true)
68+
}, internal = true, roundUpToPowerOfTwo(c.size() + 1)) // index 0 is unused
5169
addAll(c)
5270
}
5371

5472
def this(c: PriorityQueue[_ <: E]) = {
55-
this(c.comp.asInstanceOf[Comparator[_ >: E]], internal = true)
73+
this(c.comp.asInstanceOf[Comparator[_ >: E]], internal = true,
74+
roundUpToPowerOfTwo(c.size() + 1)) // index 0 is unused
5675
addAll(c)
5776
}
5877

5978
def this(sortedSet: SortedSet[_ <: E]) = {
6079
this(NaturalComparator.select(
6180
sortedSet.comparator().asInstanceOf[Comparator[_ >: E]]),
62-
internal = true)
81+
internal = true,
82+
roundUpToPowerOfTwo(sortedSet.size() + 1)) // index 0 is unused
6383
addAll(sortedSet)
6484
}
6585

6686
// The index 0 is not used; the root is at index 1.
6787
// This is standard practice in binary heaps, to simplify arithmetics.
68-
private[this] val inner = js.Array[E](null.asInstanceOf[E])
88+
private var inner: innerImpl.Repr[E] = innerImpl.make[E](initialCapacity)
6989

7090
override def add(e: E): Boolean = {
7191
if (e == null)
7292
throw new NullPointerException()
73-
inner.push(e)
74-
fixUp(inner.length - 1)
93+
val newInner = innerImpl.push(inner, e)
94+
if (LinkingInfo.isWebAssembly) // opt: for JS we know it's always the same
95+
inner = newInner
96+
fixUp(innerImpl.length(inner) - 1)
7597
true
7698
}
7799

78100
def offer(e: E): Boolean = add(e)
79101

80102
def peek(): E =
81-
if (inner.length > 1) inner(1)
103+
if (innerImpl.length(inner) > 1) innerImpl.get(inner, 1)
82104
else null.asInstanceOf[E]
83105

84106
override def remove(o: Any): Boolean = {
85107
if (o == null) {
86108
false
87109
} else {
88-
val len = inner.length
110+
val len = innerImpl.length(inner)
89111
var i = 1
90-
while (i != len && !o.equals(inner(i))) {
112+
while (i != len && !o.equals(innerImpl.get(inner, i))) {
91113
i += 1
92114
}
93115

@@ -101,9 +123,9 @@ class PriorityQueue[E] private (
101123
}
102124

103125
private def removeExact(o: Any): Unit = {
104-
val len = inner.length
126+
val len = innerImpl.length(inner)
105127
var i = 1
106-
while (i != len && (o.asInstanceOf[AnyRef] ne inner(i).asInstanceOf[AnyRef])) {
128+
while (i != len && (o.asInstanceOf[AnyRef] ne innerImpl.get(inner, i).asInstanceOf[AnyRef])) {
107129
i += 1
108130
}
109131
if (i == len)
@@ -112,12 +134,12 @@ class PriorityQueue[E] private (
112134
}
113135

114136
private def removeAt(i: Int): Unit = {
115-
val newLength = inner.length - 1
137+
val newLength = innerImpl.length(inner) - 1
116138
if (i == newLength) {
117-
inner.length = newLength
139+
innerImpl.decLength(inner)
118140
} else {
119-
inner(i) = inner(newLength)
120-
inner.length = newLength
141+
innerImpl.set(inner, i, innerImpl.get(inner, newLength))
142+
innerImpl.decLength(inner)
121143
fixUpOrDown(i)
122144
}
123145
}
@@ -126,9 +148,9 @@ class PriorityQueue[E] private (
126148
if (o == null) {
127149
false
128150
} else {
129-
val len = inner.length
151+
val len = innerImpl.length(inner)
130152
var i = 1
131-
while (i != len && !o.equals(inner(i))) {
153+
while (i != len && !o.equals(innerImpl.get(inner, i))) {
132154
i += 1
133155
}
134156
i != len
@@ -137,16 +159,16 @@ class PriorityQueue[E] private (
137159

138160
def iterator(): Iterator[E] = {
139161
new Iterator[E] {
140-
private[this] var inner: js.Array[E] = PriorityQueue.this.inner
162+
private[this] var inner: innerImpl.Repr[E] = PriorityQueue.this.inner
141163
private[this] var nextIdx: Int = 1
142164
private[this] var last: E = _ // null
143165

144-
def hasNext(): Boolean = nextIdx < inner.length
166+
def hasNext(): Boolean = nextIdx < innerImpl.length(inner)
145167

146168
def next(): E = {
147169
if (!hasNext())
148170
throw new NoSuchElementException("empty iterator")
149-
last = inner(nextIdx)
171+
last = innerImpl.get(inner, nextIdx)
150172
nextIdx += 1
151173
last
152174
}
@@ -173,27 +195,27 @@ class PriorityQueue[E] private (
173195
if (last == null)
174196
throw new IllegalStateException()
175197
if (inner eq PriorityQueue.this.inner) {
176-
inner = inner.jsSlice(nextIdx)
177-
nextIdx = 0
198+
inner = innerImpl.copyFrom(inner, nextIdx)
199+
nextIdx = 1
178200
}
179201
removeExact(last)
180202
last = null.asInstanceOf[E]
181203
}
182204
}
183205
}
184206

185-
def size(): Int = inner.length - 1
207+
def size(): Int = innerImpl.length(inner) - 1
186208

187209
override def clear(): Unit =
188-
inner.length = 1
210+
innerImpl.clear(inner)
189211

190212
def poll(): E = {
191213
val inner = this.inner // local copy
192-
if (inner.length > 1) {
193-
val newSize = inner.length - 1
194-
val result = inner(1)
195-
inner(1) = inner(newSize)
196-
inner.length = newSize
214+
if (innerImpl.length(inner) > 1) {
215+
val newSize = innerImpl.length(inner) - 1
216+
val result = innerImpl.get(inner, 1)
217+
innerImpl.set(inner, 1, innerImpl.get(inner, newSize))
218+
innerImpl.decLength(inner)
197219
fixDown(1)
198220
result
199221
} else {
@@ -212,7 +234,7 @@ class PriorityQueue[E] private (
212234
*/
213235
private[this] def fixUpOrDown(m: Int): Unit = {
214236
val inner = this.inner // local copy
215-
if (m > 1 && comp.compare(inner(m >> 1), inner(m)) > 0)
237+
if (m > 1 && comp.compare(innerImpl.get(inner, m >> 1), innerImpl.get(inner, m)) > 0)
216238
fixUp(m)
217239
else
218240
fixDown(m)
@@ -227,16 +249,16 @@ class PriorityQueue[E] private (
227249
/* At each step, even though `m` changes, the element moves with it, and
228250
* hence inner(m) is always the same initial `innerAtM`.
229251
*/
230-
val innerAtM = inner(m)
252+
val innerAtM = innerImpl.get(inner, m)
231253

232254
@inline @tailrec
233255
def loop(m: Int): Unit = {
234256
if (m > 1) {
235257
val parent = m >> 1
236-
val innerAtParent = inner(parent)
258+
val innerAtParent = innerImpl.get(inner, parent)
237259
if (comp.compare(innerAtParent, innerAtM) > 0) {
238-
inner(parent) = innerAtM
239-
inner(m) = innerAtParent
260+
innerImpl.set(inner, parent, innerAtM)
261+
innerImpl.set(inner, m, innerAtParent)
240262
loop(parent)
241263
}
242264
}
@@ -250,22 +272,22 @@ class PriorityQueue[E] private (
250272
*/
251273
private[this] def fixDown(m: Int): Unit = {
252274
val inner = this.inner // local copy
253-
val size = inner.length - 1
275+
val size = innerImpl.length(inner) - 1
254276

255277
/* At each step, even though `m` changes, the element moves with it, and
256278
* hence inner(m) is always the same initial `innerAtM`.
257279
*/
258-
val innerAtM = inner(m)
280+
val innerAtM = innerImpl.get(inner, m)
259281

260282
@inline @tailrec
261283
def loop(m: Int): Unit = {
262284
var j = 2 * m // left child of `m`
263285
if (j <= size) {
264-
var innerAtJ = inner(j)
286+
var innerAtJ = innerImpl.get(inner, j)
265287

266288
// if the left child is greater than the right child, switch to the right child
267289
if (j < size) {
268-
val innerAtJPlus1 = inner(j + 1)
290+
val innerAtJPlus1 = innerImpl.get(inner, j + 1)
269291
if (comp.compare(innerAtJ, innerAtJPlus1) > 0) {
270292
j += 1
271293
innerAtJ = innerAtJPlus1
@@ -274,13 +296,116 @@ class PriorityQueue[E] private (
274296

275297
// if the node `m` is greater than the selected child, swap and recurse
276298
if (comp.compare(innerAtM, innerAtJ) > 0) {
277-
inner(m) = innerAtJ
278-
inner(j) = innerAtM
299+
innerImpl.set(inner, m, innerAtJ)
300+
innerImpl.set(inner, j, innerAtM)
279301
loop(j)
280302
}
281303
}
282304
}
283305

284306
loop(m)
285307
}
308+
309+
}
310+
311+
object PriorityQueue {
312+
313+
/* Get the best available implementation of inner array for the given platform.
314+
*
315+
* Use Array[AnyRef] in WebAssembly to avoid JS-interop. In JS, use js.Array.
316+
* It is resizable by nature, so manual resizing is not needed.
317+
*
318+
* `linkTimeIf` is needed here to ensure the optimizer knows
319+
* there is only one implementation of `InnerArrayImpl`, and de-virtualize/inline
320+
* the function calls.
321+
*/
322+
323+
private val innerImpl: InnerArrayImpl = {
324+
LinkingInfo.linkTimeIf[InnerArrayImpl](LinkingInfo.isWebAssembly) {
325+
InnerArrayImpl.JArrayImpl
326+
} {
327+
InnerArrayImpl.JSArrayImpl
328+
}
329+
}
330+
331+
private sealed abstract class InnerArrayImpl {
332+
type Repr[E] <: AnyRef
333+
334+
def make[E](initialCapacity: Int): Repr[E]
335+
def length(v: Repr[_]): Int
336+
def decLength(v: Repr[_]): Unit
337+
def get[E](v: Repr[E], index: Int): E
338+
def set[E](v: Repr[E], index: Int, e: E): Unit
339+
def push[E](v: Repr[E], e: E): Repr[E]
340+
def copyFrom[E](v: Repr[E], from: Int): Repr[E]
341+
def clear(v: Repr[_]): Unit
342+
}
343+
344+
private object InnerArrayImpl {
345+
object JSArrayImpl extends InnerArrayImpl {
346+
import scala.scalajs.js
347+
348+
type Repr[E] = js.Array[AnyRef]
349+
350+
@inline def make[E](_initialCapacity: Int): Repr[E] = js.Array[AnyRef](null)
351+
@inline def length(v: Repr[_]): Int = v.length
352+
@inline def decLength(v: Repr[_]): Unit =
353+
v.length = v.length - 1
354+
@inline def get[E](v: Repr[E], index: Int): E = v(index).asInstanceOf[E]
355+
@inline def set[E](v: Repr[E], index: Int, e: E): Unit =
356+
v(index) = e.asInstanceOf[AnyRef]
357+
@inline def push[E](v: Repr[E], e: E): Repr[E] = {
358+
v.push(e.asInstanceOf[AnyRef])
359+
v
360+
}
361+
@inline def copyFrom[E](v: Repr[E], from: Int): Repr[E] = v.jsSlice(from - 1)
362+
@inline def clear(v: Repr[_]): Unit =
363+
v.length = 1
364+
}
365+
366+
/* We store the effective length in the index 0 of the array,
367+
* which is unused both in JSArrayImpl and in this impl.
368+
*/
369+
object JArrayImpl extends InnerArrayImpl {
370+
type Repr[E] = Array[AnyRef]
371+
372+
@inline def make[E](initialCapacity: Int): Repr[E] = {
373+
val v = new Array[AnyRef](initialCapacity)
374+
v(0) = 1.asInstanceOf[AnyRef]
375+
v
376+
}
377+
@inline def length(v: Repr[_]): Int = v(0).asInstanceOf[Int]
378+
@inline def decLength(v: Repr[_]): Unit = {
379+
val newLength = length(v) - 1
380+
v(0) = newLength.asInstanceOf[AnyRef]
381+
v(newLength) = null // free reference for GC
382+
}
383+
@inline def get[E](v: Repr[E], index: Int): E = v(index).asInstanceOf[E]
384+
@inline def set[E](v: Repr[E], index: Int, e: E): Unit =
385+
v(index) = e.asInstanceOf[AnyRef]
386+
@inline def push[E](v: Repr[E], e: E): Repr[E] = {
387+
val l = length(v)
388+
val minCapacity = l + 1
389+
val newArr =
390+
if (v.length < minCapacity)
391+
Arrays.copyOf(v, roundUpToPowerOfTwo(minCapacity))
392+
else v
393+
newArr(l) = e.asInstanceOf[AnyRef]
394+
newArr(0) = (l + 1).asInstanceOf[AnyRef]
395+
newArr
396+
}
397+
@inline def copyFrom[E](v: Repr[E], from: Int): Repr[E] = {
398+
val elemLength = length(v) - from
399+
val newArr = new Array[AnyRef](elemLength + 1)
400+
newArr(0) = (elemLength + 1).asInstanceOf[AnyRef]
401+
System.arraycopy(v, from, newArr, 1, elemLength)
402+
newArr
403+
}
404+
@inline def clear(v: Repr[_]): Unit = {
405+
Arrays.fill(v, null)
406+
v(0) = 1.asInstanceOf[AnyRef]
407+
}
408+
}
409+
}
410+
286411
}

0 commit comments

Comments
 (0)