Skip to content

Commit d7b3d57

Browse files
yanboliang authored and mengxr committed
[SPARK-11829][ML] Add read/write to estimators under ml.feature (II)
Add read/write support to the following estimators under spark.ml:
* ChiSqSelector
* PCA
* VectorIndexer
* Word2Vec

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#9838 from yanboliang/spark-11829.

(cherry picked from commit 3b7f056)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 4774897 commit d7b3d57

File tree

9 files changed

+338
-33
lines changed

9 files changed

+338
-33
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.annotation.Experimental
20+
import org.apache.hadoop.fs.Path
21+
22+
import org.apache.spark.annotation.{Experimental, Since}
2123
import org.apache.spark.ml._
2224
import org.apache.spark.ml.attribute.{AttributeGroup, _}
2325
import org.apache.spark.ml.param._
2426
import org.apache.spark.ml.param.shared._
25-
import org.apache.spark.ml.util.Identifiable
26-
import org.apache.spark.ml.util.SchemaUtils
27+
import org.apache.spark.ml.util._
2728
import org.apache.spark.mllib.feature
2829
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
2930
import org.apache.spark.mllib.regression.LabeledPoint
@@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params
6061
*/
6162
@Experimental
6263
final class ChiSqSelector(override val uid: String)
63-
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
64+
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {
6465

6566
def this() = this(Identifiable.randomUID("chiSqSelector"))
6667

@@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String)
9596
override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
9697
}
9798

99+
@Since("1.6.0")
100+
object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {
101+
102+
@Since("1.6.0")
103+
override def load(path: String): ChiSqSelector = super.load(path)
104+
}
105+
98106
/**
99107
* :: Experimental ::
100108
* Model fitted by [[ChiSqSelector]].
@@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String)
103111
final class ChiSqSelectorModel private[ml] (
104112
override val uid: String,
105113
private val chiSqSelector: feature.ChiSqSelectorModel)
106-
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
114+
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {
115+
116+
import ChiSqSelectorModel._
117+
118+
/** list of indices to select (filter). Must be ordered asc */
119+
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
107120

108121
/** @group setParam */
109122
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] (
147160
val copied = new ChiSqSelectorModel(uid, chiSqSelector)
148161
copyValues(copied, extra).setParent(parent)
149162
}
163+
164+
@Since("1.6.0")
165+
override def write: MLWriter = new ChiSqSelectorModelWriter(this)
166+
}
167+
168+
@Since("1.6.0")
169+
object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
170+
171+
private[ChiSqSelectorModel]
172+
class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {
173+
174+
private case class Data(selectedFeatures: Seq[Int])
175+
176+
override protected def saveImpl(path: String): Unit = {
177+
DefaultParamsWriter.saveMetadata(instance, path, sc)
178+
val data = Data(instance.selectedFeatures.toSeq)
179+
val dataPath = new Path(path, "data").toString
180+
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
181+
}
182+
}
183+
184+
private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {
185+
186+
private val className = classOf[ChiSqSelectorModel].getName
187+
188+
override def load(path: String): ChiSqSelectorModel = {
189+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
190+
val dataPath = new Path(path, "data").toString
191+
val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
192+
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
193+
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
194+
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
195+
DefaultParamsReader.getAndSetParams(model, metadata)
196+
model
197+
}
198+
}
199+
200+
@Since("1.6.0")
201+
override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader
202+
203+
@Since("1.6.0")
204+
override def load(path: String): ChiSqSelectorModel = super.load(path)
150205
}

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.annotation.Experimental
20+
import org.apache.hadoop.fs.Path
21+
22+
import org.apache.spark.annotation.{Experimental, Since}
2123
import org.apache.spark.ml._
2224
import org.apache.spark.ml.param._
2325
import org.apache.spark.ml.param.shared._
24-
import org.apache.spark.ml.util.Identifiable
26+
import org.apache.spark.ml.util._
2527
import org.apache.spark.mllib.feature
26-
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
28+
import org.apache.spark.mllib.linalg._
2729
import org.apache.spark.sql._
2830
import org.apache.spark.sql.functions._
2931
import org.apache.spark.sql.types.{StructField, StructType}
@@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
4951
* PCA trains a model to project vectors to a low-dimensional space using PCA.
5052
*/
5153
@Experimental
52-
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
54+
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
55+
with DefaultParamsWritable {
5356

5457
def this() = this(Identifiable.randomUID("pca"))
5558

@@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
8689
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
8790
}
8891

92+
@Since("1.6.0")
93+
object PCA extends DefaultParamsReadable[PCA] {
94+
95+
@Since("1.6.0")
96+
override def load(path: String): PCA = super.load(path)
97+
}
98+
8999
/**
90100
* :: Experimental ::
91101
* Model fitted by [[PCA]].
@@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
94104
class PCAModel private[ml] (
95105
override val uid: String,
96106
pcaModel: feature.PCAModel)
97-
extends Model[PCAModel] with PCAParams {
107+
extends Model[PCAModel] with PCAParams with MLWritable {
108+
109+
import PCAModel._
110+
111+
/** a principal components Matrix. Each column is one principal component. */
112+
val pc: DenseMatrix = pcaModel.pc
98113

99114
/** @group setParam */
100115
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,4 +142,46 @@ class PCAModel private[ml] (
127142
val copied = new PCAModel(uid, pcaModel)
128143
copyValues(copied, extra).setParent(parent)
129144
}
145+
146+
@Since("1.6.0")
147+
override def write: MLWriter = new PCAModelWriter(this)
148+
}
149+
150+
@Since("1.6.0")
151+
object PCAModel extends MLReadable[PCAModel] {
152+
153+
private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
154+
155+
private case class Data(k: Int, pc: DenseMatrix)
156+
157+
override protected def saveImpl(path: String): Unit = {
158+
DefaultParamsWriter.saveMetadata(instance, path, sc)
159+
val data = Data(instance.getK, instance.pc)
160+
val dataPath = new Path(path, "data").toString
161+
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
162+
}
163+
}
164+
165+
private class PCAModelReader extends MLReader[PCAModel] {
166+
167+
private val className = classOf[PCAModel].getName
168+
169+
override def load(path: String): PCAModel = {
170+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
171+
val dataPath = new Path(path, "data").toString
172+
val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
173+
.select("k", "pc")
174+
.head()
175+
val oldModel = new feature.PCAModel(k, pc)
176+
val model = new PCAModel(metadata.uid, oldModel)
177+
DefaultParamsReader.getAndSetParams(model, metadata)
178+
model
179+
}
180+
}
181+
182+
@Since("1.6.0")
183+
override def read: MLReader[PCAModel] = new PCAModelReader
184+
185+
@Since("1.6.0")
186+
override def load(path: String): PCAModel = super.load(path)
130187
}

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ import java.util.{Map => JMap}
2222

2323
import scala.collection.JavaConverters._
2424

25-
import org.apache.spark.annotation.Experimental
25+
import org.apache.hadoop.fs.Path
26+
27+
import org.apache.spark.annotation.{Experimental, Since}
2628
import org.apache.spark.ml.{Estimator, Model}
2729
import org.apache.spark.ml.attribute._
28-
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
30+
import org.apache.spark.ml.param._
2931
import org.apache.spark.ml.param.shared._
30-
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
32+
import org.apache.spark.ml.util._
3133
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
3234
import org.apache.spark.sql.{DataFrame, Row}
3335
import org.apache.spark.sql.functions.udf
@@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
9395
*/
9496
@Experimental
9597
class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
96-
with VectorIndexerParams {
98+
with VectorIndexerParams with DefaultParamsWritable {
9799

98100
def this() = this(Identifiable.randomUID("vecIdx"))
99101

@@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
136138
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
137139
}
138140

139-
private object VectorIndexer {
141+
@Since("1.6.0")
142+
object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
143+
144+
@Since("1.6.0")
145+
override def load(path: String): VectorIndexer = super.load(path)
140146

141147
/**
142148
* Helper class for tracking unique values for each feature.
@@ -146,7 +152,7 @@ private object VectorIndexer {
146152
* @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures.
147153
* @param maxCategories This class caps the number of unique values collected at maxCategories.
148154
*/
149-
class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
155+
private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
150156
extends Serializable {
151157

152158
/** featureValueSets[feature index] = set of unique values */
@@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] (
252258
override val uid: String,
253259
val numFeatures: Int,
254260
val categoryMaps: Map[Int, Map[Double, Int]])
255-
extends Model[VectorIndexerModel] with VectorIndexerParams {
261+
extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable {
262+
263+
import VectorIndexerModel._
256264

257265
/** Java-friendly version of [[categoryMaps]] */
258266
def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
@@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] (
408416
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
409417
copyValues(copied, extra).setParent(parent)
410418
}
419+
420+
@Since("1.6.0")
421+
override def write: MLWriter = new VectorIndexerModelWriter(this)
422+
}
423+
424+
@Since("1.6.0")
425+
object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
426+
427+
private[VectorIndexerModel]
428+
class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter {
429+
430+
private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]])
431+
432+
override protected def saveImpl(path: String): Unit = {
433+
DefaultParamsWriter.saveMetadata(instance, path, sc)
434+
val data = Data(instance.numFeatures, instance.categoryMaps)
435+
val dataPath = new Path(path, "data").toString
436+
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
437+
}
438+
}
439+
440+
private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] {
441+
442+
private val className = classOf[VectorIndexerModel].getName
443+
444+
override def load(path: String): VectorIndexerModel = {
445+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
446+
val dataPath = new Path(path, "data").toString
447+
val data = sqlContext.read.parquet(dataPath)
448+
.select("numFeatures", "categoryMaps")
449+
.head()
450+
val numFeatures = data.getAs[Int](0)
451+
val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
452+
val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
453+
DefaultParamsReader.getAndSetParams(model, metadata)
454+
model
455+
}
456+
}
457+
458+
@Since("1.6.0")
459+
override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader
460+
461+
@Since("1.6.0")
462+
override def load(path: String): VectorIndexerModel = super.load(path)
411463
}

0 commit comments

Comments (0)