Skip to content

Commit fc1a80c

Browse files
authored
Merge pull request scala-js#4367 from exoego/newkeyset
Implement ConcurrentHashMap.newKeySet and use it in VersionChecks
2 parents b859721 + 7196543 commit fc1a80c

File tree

6 files changed

+227
-61
lines changed

6 files changed

+227
-61
lines changed

ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ class VersionChecks private[ir] (
5252
}
5353

5454
private val knownSupportedBinary = {
55-
val m = new ConcurrentHashMap[String, Unit]()
56-
m.put(binaryEmitted, ())
55+
val m = ConcurrentHashMap.newKeySet[String]()
56+
m.add(binaryEmitted)
5757
m
5858
}
5959

6060
/** Check we can support this binary version (used by deserializer) */
6161
final def checkSupported(version: String): Unit = {
62-
if (!knownSupportedBinary.containsKey(version)) {
62+
if (!knownSupportedBinary.contains(version)) {
6363
val (major, minor, preRelease) = parseBinary(version)
6464
val supported = (
6565
// the exact pre-release version is supported via knownSupportedBinary
@@ -70,7 +70,7 @@ class VersionChecks private[ir] (
7070
)
7171

7272
if (supported) {
73-
knownSupportedBinary.put(version, ())
73+
knownSupportedBinary.add(version)
7474
} else {
7575
throw new IRVersionNotSupportedException(version, binaryEmitted,
7676
s"This version ($version) of Scala.js IR is not supported. " +

javalib/src/main/scala/java/util/NullRejectingHashMap.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ private[util] class NullRejectingHashMap[K, V](
5757
super.put(key, value)
5858
}
5959

60+
override def putIfAbsent(key: K, value: V): V = {
61+
if (value == null)
62+
throw new NullPointerException()
63+
val old = get(key) // throws if `key` is null
64+
if (old == null)
65+
super.put(key, value)
66+
old
67+
}
68+
6069
@noinline
6170
override def putAll(m: Map[_ <: K, _ <: V]): Unit = {
6271
/* The only purpose of `impl` is to capture the wildcards as named types,
@@ -80,6 +89,37 @@ private[util] class NullRejectingHashMap[K, V](
8089
super.remove(key)
8190
}
8291

92+
override def remove(key: Any, value: Any): Boolean = {
93+
val old = get(key) // throws if `key` is null
94+
if (old != null && old.equals(value)) { // false if `value` is null
95+
super.remove(key)
96+
true
97+
} else {
98+
false
99+
}
100+
}
101+
102+
override def replace(key: K, oldValue: V, newValue: V): Boolean = {
103+
if (oldValue == null || newValue == null)
104+
throw new NullPointerException()
105+
val old = get(key) // throws if `key` is null
106+
if (oldValue.equals(old)) { // false if `old` is null
107+
super.put(key, newValue)
108+
true
109+
} else {
110+
false
111+
}
112+
}
113+
114+
override def replace(key: K, value: V): V = {
115+
if (value == null)
116+
throw new NullPointerException()
117+
val old = get(key) // throws if `key` is null
118+
if (old != null)
119+
super.put(key, value)
120+
old
121+
}
122+
83123
override def containsValue(value: Any): Boolean = {
84124
if (value == null)
85125
throw new NullPointerException()

javalib/src/main/scala/java/util/concurrent/ConcurrentHashMap.scala

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,16 @@ class ConcurrentHashMap[K, V] private (initialCapacity: Int, loadFactor: Float)
6161
override def clear(): Unit =
6262
inner.clear()
6363

64-
override def keySet(): ConcurrentHashMap.KeySetView[K, V] =
65-
new ConcurrentHashMap.KeySetView[K, V](inner.keySet())
64+
override def keySet(): ConcurrentHashMap.KeySetView[K, V] = {
65+
// Allow null as sentinel
66+
new ConcurrentHashMap.KeySetView[K, V](this.inner, null.asInstanceOf[V])
67+
}
68+
69+
def keySet(mappedValue: V): ConcurrentHashMap.KeySetView[K, V] = {
70+
if (mappedValue == null)
71+
throw new NullPointerException()
72+
new ConcurrentHashMap.KeySetView[K, V](this.inner, mappedValue)
73+
}
6674

6775
override def values(): Collection[V] =
6876
inner.values()
@@ -79,45 +87,17 @@ class ConcurrentHashMap[K, V] private (initialCapacity: Int, loadFactor: Float)
7987
override def equals(o: Any): Boolean =
8088
inner.equals(o)
8189

82-
override def putIfAbsent(key: K, value: V): V = {
83-
if (value == null)
84-
throw new NullPointerException()
85-
val old = inner.get(key) // throws if `key` is null
86-
if (old == null)
87-
inner.put(key, value)
88-
old
89-
}
90+
override def putIfAbsent(key: K, value: V): V =
91+
inner.putIfAbsent(key, value)
9092

91-
override def remove(key: Any, value: Any): Boolean = {
92-
val old = inner.get(key) // throws if `key` is null
93-
if (old != null && old.equals(value)) { // false if `value` is null
94-
inner.remove(key)
95-
true
96-
} else {
97-
false
98-
}
99-
}
93+
override def remove(key: Any, value: Any): Boolean =
94+
inner.remove(key, value)
10095

101-
override def replace(key: K, oldValue: V, newValue: V): Boolean = {
102-
if (oldValue == null || newValue == null)
103-
throw new NullPointerException()
104-
val old = inner.get(key) // throws if `key` is null
105-
if (oldValue.equals(old)) { // false if `old` is null
106-
inner.put(key, newValue)
107-
true
108-
} else {
109-
false
110-
}
111-
}
96+
override def replace(key: K, oldValue: V, newValue: V): Boolean =
97+
inner.replace(key, oldValue, newValue)
11298

113-
override def replace(key: K, value: V): V = {
114-
if (value == null)
115-
throw new NullPointerException()
116-
val old = inner.get(key) // throws if `key` is null
117-
if (old != null)
118-
inner.put(key, value)
119-
old
120-
}
99+
override def replace(key: K, value: V): V =
100+
inner.replace(key, value)
121101

122102
def contains(value: Any): Boolean =
123103
containsValue(value)
@@ -198,39 +178,58 @@ object ConcurrentHashMap {
198178
}
199179
}
200180

201-
/* `KeySetView` is a public class in the JDK API. The result of
202-
* `ConcurrentHashMap.keySet()` must be statically typed as a `KeySetView`,
203-
* hence the existence of this class, although it forwards all its operations
204-
* to the inner key set.
205-
*/
206-
class KeySetView[K, V] private[ConcurrentHashMap] (inner: Set[K])
181+
class KeySetView[K, V] private[ConcurrentHashMap] (innerMap: InnerHashMap[K, V], defaultValue: V)
207182
extends Set[K] with Serializable {
208183

209-
def contains(o: Any): Boolean = inner.contains(o)
184+
def getMappedValue(): V = defaultValue
210185

211-
def remove(o: Any): Boolean = inner.remove(o)
186+
def contains(o: Any): Boolean = innerMap.containsKey(o)
212187

213-
def iterator(): Iterator[K] = inner.iterator()
188+
def remove(o: Any): Boolean = innerMap.remove(o) != null
214189

215-
def size(): Int = inner.size()
190+
def iterator(): Iterator[K] = innerMap.keySet().iterator()
216191

217-
def isEmpty(): Boolean = inner.isEmpty()
192+
def size(): Int = innerMap.size()
218193

219-
def toArray(): Array[AnyRef] = inner.toArray()
194+
def isEmpty(): Boolean = innerMap.isEmpty()
220195

221-
def toArray[T <: AnyRef](a: Array[T]): Array[T] = inner.toArray[T](a)
196+
def toArray(): Array[AnyRef] = innerMap.keySet().toArray()
222197

223-
def add(e: K): Boolean = inner.add(e)
198+
def toArray[T <: AnyRef](a: Array[T]): Array[T] = innerMap.keySet().toArray(a)
224199

225-
def containsAll(c: Collection[_]): Boolean = inner.containsAll(c)
200+
def add(e: K): Boolean = {
201+
if (defaultValue == null) {
202+
throw new UnsupportedOperationException()
203+
}
204+
innerMap.putIfAbsent(e, defaultValue) == null
205+
}
226206

227-
def addAll(c: Collection[_ <: K]): Boolean = inner.addAll(c)
207+
override def toString(): String = innerMap.keySet().toString
228208

229-
def removeAll(c: Collection[_]): Boolean = inner.removeAll(c)
209+
def containsAll(c: Collection[_]): Boolean = innerMap.keySet().containsAll(c)
230210

231-
def retainAll(c: Collection[_]): Boolean = inner.retainAll(c)
211+
def addAll(c: Collection[_ <: K]): Boolean = {
212+
if (defaultValue == null) {
213+
throw new UnsupportedOperationException()
214+
}
215+
val iter = c.iterator()
216+
var changed = false
217+
while (iter.hasNext())
218+
changed = innerMap.putIfAbsent(iter.next(), defaultValue) == null || changed
219+
changed
220+
}
221+
222+
def removeAll(c: Collection[_]): Boolean = innerMap.keySet().removeAll(c)
223+
224+
def retainAll(c: Collection[_]): Boolean = innerMap.keySet().retainAll(c)
232225

233-
def clear(): Unit = inner.clear()
226+
def clear(): Unit = innerMap.clear()
234227
}
235228

229+
def newKeySet[K](): KeySetView[K, Boolean] = newKeySet[K](HashMap.DEFAULT_INITIAL_CAPACITY)
230+
231+
def newKeySet[K](initialCapacity: Int): KeySetView[K, Boolean] = {
232+
val inner = new InnerHashMap[K, Boolean](initialCapacity, HashMap.DEFAULT_LOAD_FACTOR)
233+
new KeySetView[K, Boolean](inner, true)
234+
}
236235
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Scala.js (https://www.scala-js.org/)
3+
*
4+
* Copyright EPFL.
5+
*
6+
* Licensed under Apache License 2.0
7+
* (https://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package org.scalajs.testsuite.javalib.util
14+
15+
import java.{util => ju}
16+
import scala.reflect.ClassTag
17+
18+
class ConcurrentHashMapKeySetViewTest extends SetTest {
19+
override def factory: ConcurrentHashMapKeySetViewFactory = new ConcurrentHashMapKeySetViewFactory
20+
}
21+
22+
class ConcurrentHashMapKeySetViewFactory extends SetFactory {
23+
override def implementationName: String =
24+
"java.util.ConcurrentHashMap.KeySetView"
25+
26+
override def allowsNullElementQuery: Boolean = false
27+
override def allowsNullElement: Boolean = false
28+
29+
override def empty[E: ClassTag]: ju.Set[E] =
30+
ju.concurrent.ConcurrentHashMap.newKeySet[E]()
31+
}

test-suite/shared/src/test/scala/org/scalajs/testsuite/javalib/util/MapTest.scala

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ trait MapTest {
5858
assertEquals(null, mp.get(testObj(42)))
5959
if (factory.allowsNullKeysQueries)
6060
assertEquals(null, mp.get(null))
61+
else
62+
assertThrows(classOf[NullPointerException], mp.get(null))
6163
}
6264

6365
@Test def testSizeGetPutWithStringsLargeMap(): Unit = {
@@ -199,6 +201,8 @@ trait MapTest {
199201
assertNull(mp.remove(testObj(42)))
200202
if (factory.allowsNullKeys)
201203
assertNull(mp.remove(null))
204+
else
205+
assertThrows(classOf[NullPointerException], mp.remove(null))
202206
}
203207

204208
@Test def testRemoveWithInts(): Unit = {
@@ -1200,6 +1204,13 @@ trait MapTest {
12001204
mp.put("nullable", null)
12011205
assertNull(mp.putIfAbsent("nullable", "non null"))
12021206
assertEquals("non null", mp.get("nullable"))
1207+
} else {
1208+
assertThrows(classOf[NullPointerException], mp.putIfAbsent("abc", null))
1209+
assertThrows(classOf[NullPointerException], mp.putIfAbsent("new key", null))
1210+
}
1211+
1212+
if (!factory.allowsNullKeys) {
1213+
assertThrows(classOf[NullPointerException], mp.putIfAbsent(null, "def"))
12031214
}
12041215
}
12051216

@@ -1217,12 +1228,44 @@ trait MapTest {
12171228
assertTrue(mp.remove("ONE", "one"))
12181229
assertFalse(mp.containsKey("ONE"))
12191230

1231+
if (factory.allowsNullKeys) {
1232+
mp.put(null, "one")
1233+
assertFalse(mp.remove(null, "not exist"))
1234+
assertTrue(mp.containsKey(null))
1235+
assertTrue(mp.remove(null, "one"))
1236+
assertFalse(mp.containsKey(null))
1237+
} else {
1238+
assertThrows(classOf[NullPointerException], mp.remove(null, "old value"))
1239+
}
1240+
12201241
if (factory.allowsNullValues) {
12211242
mp.put("nullable", null)
12221243
assertFalse(mp.remove("nullable", "value"))
12231244
assertTrue(mp.containsKey("nullable"))
12241245
assertTrue(mp.remove("nullable", null))
12251246
assertFalse(mp.containsKey("nullable"))
1247+
} else {
1248+
// mp#(key, null) should not remove. https://bugs.java.com/bugdatabase/view_bug.do?bug_id=6272521
1249+
assertFalse(mp.remove("THREE", null))
1250+
}
1251+
}
1252+
1253+
@Test def testUnconditionalRemove(): Unit = {
1254+
val mp = factory.fromKeyValuePairs("ONE" -> "one", "TWO" -> "two", "THREE" -> "three")
1255+
1256+
assertEquals(null, mp.remove("non existing"))
1257+
assertFalse(mp.containsKey("non existing"))
1258+
1259+
assertEquals("two", mp.remove("TWO"))
1260+
assertEquals(null, mp.get("TWO"))
1261+
1262+
if (factory.allowsNullKeys) {
1263+
mp.put(null, "one")
1264+
assertTrue(mp.containsKey(null))
1265+
assertEquals("one", mp.remove(null))
1266+
assertFalse(mp.containsKey(null))
1267+
} else {
1268+
assertThrows(classOf[NullPointerException], mp.remove(null))
12261269
}
12271270
}
12281271

@@ -1252,6 +1295,17 @@ trait MapTest {
12521295
assertThrows(classOf[NullPointerException], mp.replace("ONE", null, "one"))
12531296
assertThrows(classOf[NullPointerException], mp.replace("ONE", "four", null))
12541297
}
1298+
1299+
if (factory.allowsNullKeys) {
1300+
assertFalse(null, mp.replace(null, "value", "new value"))
1301+
assertFalse(mp.containsKey(null))
1302+
1303+
mp.put(null, "null value")
1304+
assertTrue(mp.replace(null, "null value", "new value"))
1305+
assertEquals("new value", mp.get(null))
1306+
} else {
1307+
assertThrows(classOf[NullPointerException], mp.replace(null, "one", "two"))
1308+
}
12551309
}
12561310

12571311
@Test def testUnconditionalReplace(): Unit = {
@@ -1419,6 +1473,15 @@ trait MapTest {
14191473
assertEquals("def", mp.get("nullable"))
14201474
}
14211475
}
1476+
1477+
@Test def additionToKeySet(): Unit = {
1478+
val set = factory.empty[String, String].keySet()
1479+
1480+
expectThrows(classOf[UnsupportedOperationException], set.add("ONE"))
1481+
expectThrows(classOf[UnsupportedOperationException], set.addAll(ju.Arrays.asList("ONE")))
1482+
expectThrows(classOf[UnsupportedOperationException], set.addAll(ju.Arrays.asList(null)))
1483+
expectThrows(classOf[UnsupportedOperationException], set.add(null))
1484+
}
14221485
}
14231486

14241487
object MapTest {

0 commit comments

Comments
 (0)