Skip to content

Commit 083fb8b

Browse files
authored
Merge pull request #9565 from lrytz/t12228
2 parents 5c16b3c + 24476bc commit 083fb8b

File tree

10 files changed

+147
-15
lines changed

10 files changed

+147
-15
lines changed

src/library/scala/collection/Map.scala

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,44 @@ trait Map[K, +V]
2929

3030
def canEqual(that: Any): Boolean = true
3131

32+
/**
33+
* Equality of maps is implemented using the lookup method [[get]]. This method returns `true` if
34+
* - the argument `o` is a `Map`,
35+
* - the two maps have the same [[size]], and
36+
* - for every `(key, value)` pair in this map, `other.get(key) == Some(value)`.
37+
*
38+
* The implementation of `equals` checks the [[canEqual]] method, so subclasses of `Map` can narrow down the equality
39+
* to specific map types. The `Map` implementations in the standard library can all be compared, their `canEqual`
40+
* methods return `true`.
41+
*
42+
* Note: The `equals` method only respects the equality laws (symmetry, transitivity) if the two maps use the same
43+
* key equivalence function in their lookup operation. For example, the key equivalence operation in a
44+
* [[scala.collection.immutable.TreeMap]] is defined by its ordering. Comparing a `TreeMap` with a `HashMap` leads
45+
* to unexpected results if `ordering.equiv(k1, k2)` (used for lookup in `TreeMap`) is different from `k1 == k2`
46+
* (used for lookup in `HashMap`).
47+
*
48+
* {{{
49+
* scala> import scala.collection.immutable._
50+
* scala> val ord: Ordering[String] = _ compareToIgnoreCase _
51+
*
52+
* scala> TreeMap("A" -> 1)(ord) == HashMap("a" -> 1)
53+
* val res0: Boolean = false
54+
*
55+
* scala> HashMap("a" -> 1) == TreeMap("A" -> 1)(ord)
56+
* val res1: Boolean = true
57+
* }}}
58+
*
59+
*
60+
* @param o The map to which this map is compared
61+
* @return `true` if the two maps are equal according to the description
62+
*/
3263
override def equals(o: Any): Boolean =
3364
(this eq o.asInstanceOf[AnyRef]) || (o match {
3465
case map: Map[K, _] if map.canEqual(this) =>
35-
(this.size == map.size) &&
36-
this.forall(kv => map.getOrElse(kv._1, Map.DefaultSentinelFn()) == kv._2)
66+
(this.size == map.size) && {
67+
try this.forall(kv => map.getOrElse(kv._1, Map.DefaultSentinelFn()) == kv._2)
68+
catch { case _: ClassCastException => false } // PR #9565 / scala/bug#12228
69+
}
3770
case _ =>
3871
false
3972
})

src/library/scala/collection/Set.scala

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,44 @@ trait Set[A]
2828

2929
def canEqual(that: Any) = true
3030

31+
/**
32+
* Equality of sets is implemented using the lookup method [[contains]]. This method returns `true` if
33+
* - the argument `that` is a `Set`,
34+
* - the two sets have the same [[size]], and
35+
* - for every `element` this set, `other.contains(element) == true`.
36+
*
37+
* The implementation of `equals` checks the [[canEqual]] method, so subclasses of `Set` can narrow down the equality
38+
* to specific set types. The `Set` implementations in the standard library can all be compared, their `canEqual`
39+
* methods return `true`.
40+
*
41+
* Note: The `equals` method only respects the equality laws (symmetry, transitivity) if the two sets use the same
42+
* element equivalence function in their lookup operation. For example, the element equivalence operation in a
43+
* [[scala.collection.immutable.TreeSet]] is defined by its ordering. Comparing a `TreeSet` with a `HashSet` leads
44+
* to unexpected results if `ordering.equiv(e1, e2)` (used for lookup in `TreeSet`) is different from `e1 == e2`
45+
* (used for lookup in `HashSet`).
46+
*
47+
* {{{
48+
* scala> import scala.collection.immutable._
49+
* scala> val ord: Ordering[String] = _ compareToIgnoreCase _
50+
*
51+
* scala> TreeSet("A")(ord) == HashSet("a")
52+
* val res0: Boolean = false
53+
*
54+
* scala> HashSet("a") == TreeSet("A")(ord)
55+
* val res1: Boolean = true
56+
* }}}
57+
*
58+
*
59+
* @param that The set to which this set is compared
60+
* @return `true` if the two sets are equal according to the description
61+
*/
3162
override def equals(that: Any): Boolean =
3263
(this eq that.asInstanceOf[AnyRef]) || (that match {
3364
case set: Set[A] if set.canEqual(this) =>
34-
(this.size == set.size) && this.subsetOf(set)
65+
(this.size == set.size) && {
66+
try this.subsetOf(set)
67+
catch { case _: ClassCastException => false } // PR #9565 / scala/bug#12228
68+
}
3569
case _ =>
3670
false
3771
})

src/library/scala/collection/SortedMap.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@ trait SortedMap[K, +V]
3030

3131
override def equals(that: Any): Boolean = that match {
3232
case _ if this eq that.asInstanceOf[AnyRef] => true
33-
case sm: SortedMap[k, v] if sm.ordering == this.ordering =>
33+
case sm: SortedMap[K, _] if sm.ordering == this.ordering =>
3434
(sm canEqual this) &&
3535
(this.size == sm.size) && {
3636
val i1 = this.iterator
3737
val i2 = sm.iterator
3838
var allEqual = true
39-
while (allEqual && i1.hasNext)
40-
allEqual = i1.next() == i2.next()
39+
while (allEqual && i1.hasNext) {
40+
val kv1 = i1.next()
41+
val kv2 = i2.next()
42+
allEqual = ordering.equiv(kv1._1, kv2._1) && kv1._2 == kv2._2
43+
}
4144
allEqual
4245
}
4346
case _ => super.equals(that)

src/library/scala/collection/SortedSet.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ trait SortedSet[A] extends Set[A]
2929

3030
override def equals(that: Any): Boolean = that match {
3131
case _ if this eq that.asInstanceOf[AnyRef] => true
32-
case ss: SortedSet[_] if ss.ordering == this.ordering =>
32+
case ss: SortedSet[A] if ss.ordering == this.ordering =>
3333
(ss canEqual this) &&
3434
(this.size == ss.size) && {
3535
val i1 = this.iterator
3636
val i2 = ss.iterator
3737
var allEqual = true
3838
while (allEqual && i1.hasNext)
39-
allEqual = i1.next() == i2.next()
39+
allEqual = ordering.equiv(i1.next(), i2.next())
4040
allEqual
4141
}
4242
case _ =>

src/library/scala/collection/immutable/HashMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
254254

255255
override def equals(that: Any): Boolean =
256256
that match {
257-
case map: HashMap[K, V] => (this eq map) || (this.rootNode == map.rootNode)
257+
case map: HashMap[_, _] => (this eq map) || (this.rootNode == map.rootNode)
258258
case _ => super.equals(that)
259259
}
260260

src/library/scala/collection/immutable/HashSet.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ final class HashSet[A] private[immutable](private[immutable] val rootNode: Bitma
176176

177177
override def equals(that: Any): Boolean =
178178
that match {
179-
case set: HashSet[A] => (this eq set) || (this.rootNode == set.rootNode)
179+
case set: HashSet[_] => (this eq set) || (this.rootNode == set.rootNode)
180180
case _ => super.equals(that)
181181
}
182182

src/library/scala/collection/immutable/LongMap.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ object LongMap {
6363
private[immutable] case object Nil extends LongMap[Nothing] {
6464
// Important, don't remove this! See IntMap for explanation.
6565
override def equals(that : Any) = that match {
66-
case (that: AnyRef) if (this eq that) => true
67-
case (that: LongMap[_]) => false // The only empty LongMaps are eq Nil
68-
case that => super.equals(that)
66+
case _: this.type => true
67+
case _: LongMap[_] => false // The only empty LongMaps are eq Nil
68+
case _ => super.equals(that)
6969
}
7070
}
7171

src/library/scala/collection/immutable/TreeMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ final class TreeMap[K, +V] private (private val tree: RB.Tree[K, V])(implicit va
283283
}
284284
}
285285
override def equals(obj: Any): Boolean = obj match {
286-
case that: TreeMap[K, V] if ordering == that.ordering => RB.entriesEqual(tree, that.tree)
286+
case that: TreeMap[K, _] if ordering == that.ordering => RB.entriesEqual(tree, that.tree)
287287
case _ => super.equals(obj)
288288
}
289289

test/junit/scala/collection/MapTest.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,9 @@ class MapTest {
123123
check(mutable.CollisionProofHashMap(1 -> 1))
124124
}
125125

126+
@Test
127+
def t12228(): Unit = {
128+
assertFalse(Set("") == immutable.BitSet(1))
129+
assertFalse(Map("" -> 2) == scala.collection.immutable.LongMap(1L -> 2))
130+
}
126131
}

test/junit/scala/collection/SortedSetMapEqualsTest.scala

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package scala.collection
22

3-
import org.junit.{Assert, Test}, Assert.assertEquals
3+
import org.junit.{Assert, Test}
4+
import Assert.{assertEquals, assertNotEquals}
45

56
class SortedSetMapEqualsTest {
67
@Test
@@ -68,4 +69,60 @@ class SortedSetMapEqualsTest {
6869
}
6970
assertEquals(m1, m2)
7071
}
72+
73+
@Test
74+
def compareSortedMapKeysByOrdering(): Unit = {
75+
val ord: Ordering[String] = _ compareToIgnoreCase _
76+
77+
val itm1 = scala.collection.immutable.TreeMap("A" -> "2")(ord)
78+
val itm2 = scala.collection.immutable.TreeMap("a" -> "2")(ord)
79+
val mtm1 = scala.collection.mutable.TreeMap("A" -> "2")(ord)
80+
val mtm2 = scala.collection.mutable.TreeMap("a" -> "2")(ord)
81+
82+
assertEquals(itm1, itm2)
83+
assertEquals(mtm1, mtm2)
84+
85+
assertEquals(itm1, mtm2)
86+
assertEquals(mtm1, itm2)
87+
88+
val m1 = Map("A" -> "2")
89+
val m2 = Map("a" -> "2")
90+
91+
for (m <- List(m1, m2); tm <- List[Map[String, String]](itm1, itm2, mtm1, mtm2))
92+
assertEquals(m, tm) // uses keys in `m` to look up values in `tm`, which always succeeds
93+
94+
assertEquals(itm1, m1)
95+
assertEquals(mtm1, m1)
96+
97+
assertNotEquals(itm2, m1) // uses key in `itm2` ("a") to look up in `m1`, which fails
98+
assertNotEquals(mtm2, m1)
99+
}
100+
101+
@Test
102+
def compareSortedSetsByOrdering(): Unit = {
103+
val ord: Ordering[String] = _ compareToIgnoreCase _
104+
105+
val its1 = scala.collection.immutable.TreeSet("A")(ord)
106+
val its2 = scala.collection.immutable.TreeSet("a")(ord)
107+
val mts1 = scala.collection.mutable.TreeSet("A")(ord)
108+
val mts2 = scala.collection.mutable.TreeSet("a")(ord)
109+
110+
assertEquals(its1, its2)
111+
assertEquals(mts1, mts2)
112+
113+
assertEquals(its1, mts2)
114+
assertEquals(mts1, its2)
115+
116+
val s1 = Set("A")
117+
val s2 = Set("a")
118+
119+
for (m <- List(s1, s2); tm <- List[Set[String]](its1, its2, mts1, mts2))
120+
assertEquals(m, tm) // uses keys in `m` to look up values in `tm`, which always succeeds
121+
122+
assertEquals(its1, s1)
123+
assertEquals(mts1, s1)
124+
125+
assertNotEquals(its2, s1) // uses key in `its2` ("a") to look up in `s1`, which fails
126+
assertNotEquals(mts2, s1)
127+
}
71128
}

0 commit comments

Comments
 (0)