Commit 1ce6394

yinxusen authored and mengxr committed
[SPARK-11867] Add save/load for kmeans and naive bayes
https://issues.apache.org/jira/browse/SPARK-11867

Author: Xusen Yin <yinxusen@gmail.com>

Closes apache#9849 from yinxusen/SPARK-11867.

(cherry picked from commit 3e1d120)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 60d9375 commit 1ce6394

File tree

4 files changed: +195 -28 lines changed
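
As a quick orientation before the diffs: the new persistence support surfaces through the standard spark.ml writer/reader API. Below is a minimal usage sketch for the Naive Bayes side; the training DataFrame and the output path are illustrative assumptions, not part of this patch.

  import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

  // `training` is assumed to be a DataFrame with "label" and "features" columns.
  val nb = new NaiveBayes().setSmoothing(1.0)
  val model = nb.fit(training)
  // MLWritable gives the fitted model write/save; overwrite() is optional.
  model.write.overwrite().save("/tmp/spark-nb-model")
  // MLReadable on the companion object gives load.
  val restored = NaiveBayesModel.load("/tmp/spark-nb-model")
  assert(restored.pi == model.pi && restored.theta == model.theta)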


mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 61 additions & 7 deletions
@@ -17,12 +17,15 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
 @Experimental
 class NaiveBayes(override val uid: String)
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
-  with NaiveBayesParams {
+  with NaiveBayesParams with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("nb"))
 
@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
   override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayes = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
     override val uid: String,
     val pi: Vector,
     val theta: Matrix)
-  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
+  with NaiveBayesParams with MLWritable {
 
   import OldNaiveBayes.{Bernoulli, Multinomial}
 
@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
     s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
   }
 
+  @Since("1.6.0")
+  override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
 }
 
-private[ml] object NaiveBayesModel {
+@Since("1.6.0")
+object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 
   /** Convert a model from the old API */
-  def fromOld(
+  private[ml] def fromOld(
      oldModel: OldNaiveBayesModel,
      parent: NaiveBayes): NaiveBayesModel = {
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
       oldModel.theta.flatten, true)
     new NaiveBayesModel(uid, pi, theta)
   }
+
+  @Since("1.6.0")
+  override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayesModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[NaiveBayesModel]] */
+  private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+
+    private case class Data(pi: Vector, theta: Matrix)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: pi, theta
+      val data = Data(instance.pi, instance.theta)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[NaiveBayesModel].getName
+
+    override def load(path: String): NaiveBayesModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+      val pi = data.getAs[Vector](0)
+      val theta = data.getAs[Matrix](1)
+      val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
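
The writer added above produces the usual two-part layout under the target directory: JSON metadata written by DefaultParamsWriter.saveMetadata, plus a single-row Parquet file under data/ holding pi and theta. A small sketch for inspecting a saved model, assuming a spark-shell style sc and sqlContext and an illustrative modelPath variable pointing at a directory written by this code:

  import org.apache.hadoop.fs.Path

  // uid, class name and Params, serialized as JSON text by DefaultParamsWriter
  sc.textFile(new Path(modelPath, "metadata").toString).collect().foreach(println)
  // model data: one row with a Vector column "pi" and a Matrix column "theta"
  sqlContext.read.parquet(new Path(modelPath, "data").toString).printSchema()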

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 61 additions & 6 deletions
@@ -17,18 +17,19 @@
 
 package org.apache.spark.ml.clustering
 
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 
-
 /**
  * Common params for KMeans and KMeansModel
  */
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+    private val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansParams with MLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+}
+
+@Since("1.6.0")
+object KMeansModel extends MLReadable[KMeansModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[KMeansModel] = new KMeansModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): KMeansModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[KMeansModel]] */
+  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+
+    private case class Data(clusterCenters: Array[Vector])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: cluster centers
+      val data = Data(instance.clusterCenters)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class KMeansModelReader extends MLReader[KMeansModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[KMeansModel].getName
+
+    override def load(path: String): KMeansModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
+      val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+      val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
 @Experimental
 class KMeans @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
-  extends Estimator[KMeansModel] with KMeansParams {
+  extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
 
   setDefault(
     k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
   }
 }
 
+@Since("1.6.0")
+object KMeans extends DefaultParamsReadable[KMeans] {
+
+  @Since("1.6.0")
+  override def load(path: String): KMeans = super.load(path)
+}
+
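
The clustering change mirrors the classification one: the KMeans estimator gains save/load through DefaultParamsWritable/DefaultParamsReadable, and KMeansModel persists its cluster centers. A minimal usage sketch; the input DataFrame and the paths are illustrative assumptions.

  import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

  // `points` is assumed to be a DataFrame with a "features" Vector column.
  val kmeans = new KMeans().setK(3).setMaxIter(20)
  kmeans.write.overwrite().save("/tmp/spark-kmeans-estimator")  // Params only
  val model = kmeans.fit(points)
  model.write.overwrite().save("/tmp/spark-kmeans-model")       // Params + cluster centers
  val restored = KMeansModel.load("/tmp/spark-kmeans-model")
  restored.clusterCenters.foreach(println)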

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

Lines changed: 42 additions & 5 deletions
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.classification.NaiveBayesSuite._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val pi = Array(0.5, 0.1, 0.4).map(math.log)
+    val theta = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
 
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+    dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+  }
 
   def validatePrediction(predictionAndLabels: DataFrame): Unit = {
     val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "bernoulli")
   }
+
+  test("read/write") {
+    def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+      assert(model.pi === model2.pi)
+      assert(model.theta === model2.theta)
+    }
+    val nb = new NaiveBayes()
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+  }
+}
+
+object NaiveBayesSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "smoothing" -> 0.1
+  )
 }
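
The read/write test above relies on testEstimatorAndModelReadWrite from the shared DefaultReadWriteTest trait: it applies allParamSettings to the estimator, round-trips both the estimator and the fitted model through save/load, verifies that the Params survive, and finally hands the original and reloaded models to checkModelData. A condensed approximation of that flow, for orientation only (not the helper's actual code; tmpDir is an assumed temp directory):

  val nb = new NaiveBayes().setPredictionCol("myPrediction").setSmoothing(0.1)
  nb.write.overwrite().save(s"$tmpDir/estimator")
  val nb2 = NaiveBayes.load(s"$tmpDir/estimator")
  assert(nb2.getSmoothing == nb.getSmoothing)  // estimator Params round-trip

  val model = nb.fit(dataset)
  model.write.overwrite().save(s"$tmpDir/model")
  val model2 = NaiveBayesModel.load(s"$tmpDir/model")
  assert(model2.pi == model.pi && model2.theta == model.theta)  // checkModelData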

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 31 additions & 10 deletions
@@ -18,23 +18,15 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
 private[clustering] case class TestRow(features: Vector)
 
-object KMeansSuite {
-  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
-    val sc = sql.sparkContext
-    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
-      .map(v => new TestRow(v))
-    sql.createDataFrame(rdd)
-  }
-}
-
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
   @transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
   }
+
+  test("read/write") {
+    def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val kmeans = new KMeans()
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object KMeansSuite {
+  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+    val sc = sql.sparkContext
+    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+      .map(v => new TestRow(v))
+    sql.createDataFrame(rdd)
+  }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "k" -> 3,
+    "maxIter" -> 2,
+    "tol" -> 0.01
+  )
 }
