Skip to content

Commit 8623c2b

Browse files
committed
Merge pull request scala#4431 from adriaanm/rebase-4379
Patmat: efficient reasoning about mutual exclusion
2 parents 6e7b326 + d44a86f commit 8623c2b

File tree

7 files changed

+1034
-44
lines changed

7 files changed

+1034
-44
lines changed

src/compiler/scala/tools/nsc/transform/patmat/Logic.scala

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package tools.nsc.transform.patmat
1010
import scala.language.postfixOps
1111
import scala.collection.mutable
1212
import scala.reflect.internal.util.{NoPosition, Position, Statistics, HashSet}
13+
import scala.tools.nsc.Global
1314

1415
trait Logic extends Debugging {
1516
import PatternMatchingStats._
@@ -90,6 +91,8 @@ trait Logic extends Debugging {
9091
// compute the domain and return it (call registerNull first!)
9192
def domainSyms: Option[Set[Sym]]
9293

94+
def groupedDomains: List[Set[Sym]]
95+
9396
// the symbol for this variable being equal to its statically known type
9497
// (only available if registerEquality has been called for that type before)
9598
def symForStaticTp: Option[Sym]
@@ -118,6 +121,9 @@ trait Logic extends Debugging {
118121

119122
final case class Not(a: Prop) extends Prop
120123

124+
// mutually exclusive (i.e., not more than one symbol is set)
125+
final case class AtMostOne(ops: List[Sym]) extends Prop
126+
121127
case object True extends Prop
122128
case object False extends Prop
123129

@@ -192,7 +198,8 @@ trait Logic extends Debugging {
192198
case Not(negated) => negationNormalFormNot(negated)
193199
case True
194200
| False
195-
| (_: Sym) => p
201+
| (_: Sym)
202+
| (_: AtMostOne) => p
196203
}
197204

198205
def simplifyProp(p: Prop): Prop = p match {
@@ -252,6 +259,7 @@ trait Logic extends Debugging {
252259
case Not(a) => apply(a)
253260
case Eq(a, b) => applyVar(a); applyConst(b)
254261
case s: Sym => applySymbol(s)
262+
case AtMostOne(ops) => ops.foreach(applySymbol)
255263
case _ =>
256264
}
257265
def applyVar(x: Var): Unit = {}
@@ -374,7 +382,23 @@ trait Logic extends Debugging {
374382
// when sym is true, what must hold...
375383
implied foreach (impliedSym => addAxiom(Or(Not(sym), impliedSym)))
376384
// ... and what must not?
377-
excluded foreach (excludedSym => addAxiom(Or(Not(sym), Not(excludedSym))))
385+
excluded foreach {
386+
excludedSym =>
387+
val related = Set(sym, excludedSym)
388+
val exclusive = v.groupedDomains.exists {
389+
domain => related subsetOf domain.toSet
390+
}
391+
392+
// TODO: populate `v.exclusiveDomains` with `Set`s from the start, and optimize to:
393+
// val exclusive = v.exclusiveDomains.exists { inDomain => inDomain(sym) && inDomain(excludedSym) }
394+
if (!exclusive)
395+
addAxiom(Or(Not(sym), Not(excludedSym)))
396+
}
397+
}
398+
399+
// all symbols in a domain are mutually exclusive
400+
v.groupedDomains.foreach {
401+
syms => if (syms.size > 1) addAxiom(AtMostOne(syms.toList))
378402
}
379403
}
380404

@@ -449,7 +473,9 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
449473
// once we go to run-time checks (on Const's), convert them to checkable types
450474
// TODO: there seems to be bug for singleton domains (variable does not show up in model)
451475
lazy val domain: Option[Set[Const]] = {
452-
val subConsts = enumerateSubtypes(staticTp).map{ tps =>
476+
val subConsts =
477+
enumerateSubtypes(staticTp, grouped = false)
478+
.headOption.map { tps =>
453479
tps.toSet[Type].map{ tp =>
454480
val domainC = TypeConst(tp)
455481
registerEquality(domainC)
@@ -467,6 +493,15 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
467493
observed(); allConsts
468494
}
469495

496+
lazy val groupedDomains: List[Set[Sym]] = {
497+
val subtypes = enumerateSubtypes(staticTp, grouped = true)
498+
subtypes.map {
499+
subTypes =>
500+
val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet
501+
if (mayBeNull) syms + symForEqualsTo(NullConst) else syms
502+
}.filter(_.nonEmpty)
503+
}
504+
470505
// populate equalitySyms
471506
// don't care about the result, but want only one fresh symbol per distinct constant c
472507
def registerEquality(c: Const): Unit = {ensureCanModify(); symForEqualsTo getOrElseUpdate(c, Sym(this, c))}

src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -95,58 +95,84 @@ trait TreeAndTypeAnalysis extends Debugging {
9595
val typer: Typer
9696

9797
// TODO: domain of other feasibly enumerable built-in types (char?)
98-
def enumerateSubtypes(tp: Type): Option[List[Type]] =
98+
def enumerateSubtypes(tp: Type, grouped: Boolean): List[List[Type]] =
9999
tp.typeSymbol match {
100100
// TODO case _ if tp.isTupleType => // recurse into component types?
101-
case UnitClass =>
102-
Some(List(UnitTpe))
103-
case BooleanClass =>
104-
Some(ConstantTrue :: ConstantFalse :: Nil)
101+
case UnitClass if !grouped =>
102+
List(List(UnitTpe))
103+
case BooleanClass if !grouped =>
104+
List(ConstantTrue :: ConstantFalse :: Nil)
105105
// TODO case _ if tp.isTupleType => // recurse into component types
106-
case modSym: ModuleClassSymbol =>
107-
Some(List(tp))
106+
case modSym: ModuleClassSymbol if !grouped =>
107+
List(List(tp))
108108
case sym: RefinementClassSymbol =>
109-
val parentSubtypes: List[Option[List[Type]]] = tp.parents.map(parent => enumerateSubtypes(parent))
110-
if (parentSubtypes exists (_.isDefined))
109+
val parentSubtypes = tp.parents.flatMap(parent => enumerateSubtypes(parent, grouped))
110+
if (parentSubtypes exists (_.nonEmpty)) {
111111
// If any of the parents is enumerable, then the refinement type is enumerable.
112-
Some(
113-
// We must only include subtypes of the parents that conform to `tp`.
114-
// See neg/virtpatmat_exhaust_compound.scala for an example.
115-
parentSubtypes flatMap (_.getOrElse(Nil)) filter (_ <:< tp)
116-
)
117-
else None
112+
// We must only include subtypes of the parents that conform to `tp`.
113+
// See neg/virtpatmat_exhaust_compound.scala for an example.
114+
parentSubtypes map (_.filter(_ <:< tp))
115+
}
116+
else Nil
118117
// make sure it's not a primitive, else (5: Byte) match { case 5 => ... } sees no Byte
119118
case sym if sym.isSealed =>
120-
val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")(
121-
// symbols which are both sealed and abstract need not be covered themselves, because
122-
// all of their children must be and they cannot otherwise be created.
123-
sym.sealedDescendants.toList
124-
sortBy (_.sealedSortName)
125-
filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
126-
)
127119

128120
val tpApprox = typer.infer.approximateAbstracts(tp)
129121
val pre = tpApprox.prefix
130122

131-
Some(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") {
132-
// valid subtypes are turned into checkable types, as we are entering the realm of the dynamic
133-
subclasses flatMap { sym =>
123+
def filterChildren(children: List[Symbol]): List[Type] = {
124+
children flatMap { sym =>
134125
// have to filter out children which cannot match: see ticket #3683 for an example
135126
// compare to the fully known type `tp` (modulo abstract types),
136127
// so that we can rule out stuff like: sealed trait X[T]; class XInt extends X[Int] --> XInt not valid when enumerating X[String]
137128
// however, must approximate abstract types in
138129

139-
val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner)
140-
val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType))
130+
val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner)
131+
val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType))
141132
val subTpApprox = typer.infer.approximateAbstracts(subTp) // TODO: needed?
142133
// debug.patmat("subtp"+(subTpApprox <:< tpApprox, subTpApprox, tpApprox))
143134
if (subTpApprox <:< tpApprox) Some(checkableType(subTp))
144135
else None
145136
}
146-
})
137+
}
138+
139+
if(grouped) {
140+
def enumerateChildren(sym: Symbol) = {
141+
sym.children.toList
142+
.sortBy(_.sealedSortName)
143+
.filterNot(x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
144+
}
145+
146+
// enumerate only direct subclasses,
147+
// subclasses of subclasses are enumerated in the next iteration
148+
// and added to a new group
149+
def groupChildren(wl: List[Symbol],
150+
acc: List[List[Type]]): List[List[Type]] = wl match {
151+
case hd :: tl =>
152+
val children = enumerateChildren(hd)
153+
groupChildren(tl ++ children, acc :+ filterChildren(children))
154+
case Nil => acc
155+
}
156+
157+
groupChildren(sym :: Nil, Nil)
158+
} else {
159+
val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")(
160+
// symbols which are both sealed and abstract need not be covered themselves, because
161+
// all of their children must be and they cannot otherwise be created.
162+
sym.sealedDescendants.toList
163+
sortBy (_.sealedSortName)
164+
filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
165+
)
166+
167+
List(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") {
168+
// valid subtypes are turned into checkable types, as we are entering the realm of the dynamic
169+
filterChildren(subclasses)
170+
})
171+
}
172+
147173
case sym =>
148174
debug.patmat("enum unsealed "+ ((tp, sym, sym.isSealed, isPrimitiveValueClass(sym))))
149-
None
175+
Nil
150176
}
151177

152178
// approximate a type to the static type that is fully checkable at run time,
@@ -176,7 +202,7 @@ trait TreeAndTypeAnalysis extends Debugging {
176202
def uncheckableType(tp: Type): Boolean = {
177203
val checkable = (
178204
(isTupleType(tp) && tupleComponents(tp).exists(tp => !uncheckableType(tp)))
179-
|| enumerateSubtypes(tp).nonEmpty)
205+
|| enumerateSubtypes(tp, grouped = false).nonEmpty)
180206
// if (!checkable) debug.patmat("deemed uncheckable: "+ tp)
181207
!checkable
182208
}

src/compiler/scala/tools/nsc/transform/patmat/Solving.scala

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,22 @@ trait Solving extends Logic {
6565
def size = symbols.size
6666
}
6767

68+
def cnfString(f: Array[Clause]): String
69+
6870
final case class Solvable(cnf: Cnf, symbolMapping: SymbolMapping) {
6971
def ++(other: Solvable) = {
7072
require(this.symbolMapping eq other.symbolMapping)
7173
Solvable(cnf ++ other.cnf, symbolMapping)
7274
}
75+
76+
override def toString: String = {
77+
"Solvable\nLiterals:\n" +
78+
(for {
79+
(lit, sym) <- symbolMapping.symForVar.toSeq.sortBy(_._1)
80+
} yield {
81+
s"$lit -> $sym"
82+
}).mkString("\n") + "Cnf:\n" + cnfString(cnf)
83+
}
7384
}
7485

7586
trait CnfBuilder {
@@ -140,20 +151,23 @@ trait Solving extends Logic {
140151

141152
def apply(p: Prop): Solvable = {
142153

143-
def convert(p: Prop): Lit = {
154+
def convert(p: Prop): Option[Lit] = {
144155
p match {
145156
case And(fv) =>
146-
and(fv.map(convert))
157+
Some(and(fv.flatMap(convert)))
147158
case Or(fv) =>
148-
or(fv.map(convert))
159+
Some(or(fv.flatMap(convert)))
149160
case Not(a) =>
150-
not(convert(a))
161+
convert(a).map(not)
151162
case sym: Sym =>
152-
convertSym(sym)
163+
Some(convertSym(sym))
153164
case True =>
154-
constTrue
165+
Some(constTrue)
155166
case False =>
156-
constFalse
167+
Some(constFalse)
168+
case AtMostOne(ops) =>
169+
atMostOne(ops)
170+
None
157171
case _: Eq =>
158172
throw new MatchError(p)
159173
}
@@ -199,8 +213,57 @@ trait Solving extends Logic {
199213
// no need for auxiliary variable
200214
def not(a: Lit): Lit = -a
201215

216+
/**
217+
* This encoding adds 3n-4 variables auxiliary variables
218+
* to encode that at most 1 symbol can be set.
219+
* See also "Towards an Optimal CNF Encoding of Boolean Cardinality Constraints"
220+
* http://www.carstensinz.de/papers/CP-2005.pdf
221+
*/
222+
def atMostOne(ops: List[Sym]) {
223+
(ops: @unchecked) match {
224+
case hd :: Nil => convertSym(hd)
225+
case x1 :: tail =>
226+
// sequential counter: 3n-4 clauses
227+
// pairwise encoding: n*(n-1)/2 clauses
228+
// thus pays off only if n > 5
229+
if (ops.lengthCompare(5) > 0) {
230+
231+
@inline
232+
def /\(a: Lit, b: Lit) = addClauseProcessed(clause(a, b))
233+
234+
val (mid, xn :: Nil) = tail.splitAt(tail.size - 1)
235+
236+
// 1 <= x1,...,xn <==>
237+
//
238+
// (!x1 \/ s1) /\ (!xn \/ !sn-1) /\
239+
//
240+
// /\
241+
// / \ (!xi \/ si) /\ (!si-1 \/ si) /\ (!xi \/ !si-1)
242+
// 1 < i < n
243+
val s1 = newLiteral()
244+
/\(-convertSym(x1), s1)
245+
val snMinus = mid.foldLeft(s1) {
246+
case (siMinus, sym) =>
247+
val xi = convertSym(sym)
248+
val si = newLiteral()
249+
/\(-xi, si)
250+
/\(-siMinus, si)
251+
/\(-xi, -siMinus)
252+
si
253+
}
254+
/\(-convertSym(xn), -snMinus)
255+
} else {
256+
ops.map(convertSym).combinations(2).foreach {
257+
case a :: b :: Nil =>
258+
addClauseProcessed(clause(-a, -b))
259+
case _ =>
260+
}
261+
}
262+
}
263+
}
264+
202265
// add intermediate variable since we want the formula to be SAT!
203-
addClauseProcessed(clause(convert(p)))
266+
addClauseProcessed(convert(p).toSet)
204267

205268
Solvable(buildCnf, symbolMapping)
206269
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
patmatexhaust-huge.scala:404: warning: match may not be exhaustive.
2+
It would fail on the following inputs: C392, C397
3+
def f(c: C): Int = c match {
4+
^
5+
error: No warnings can be incurred under -Xfatal-warnings.
6+
one warning found
7+
one error found
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-Xfatal-warnings -unchecked -Ypatmat-exhaust-depth off

0 commit comments

Comments
 (0)