
Commit 4774897

yinxusen authored and mengxr committed
[SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression
https://issues.apache.org/jira/browse/SPARK-11846 mengxr

Author: Xusen Yin <yinxusen@gmail.com>

Closes apache#9836 from yinxusen/SPARK-11846.

(cherry picked from commit 4114ce2)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent a936fa5 commit 4774897

File tree: 4 files changed, +210 -22 lines

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 69 additions & 9 deletions

@@ -21,20 +21,20 @@ import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.{Model, Estimator}
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
-import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, DataFrame}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{Logging, SparkException}
 
 /**
  * Params for accelerated failure time (AFT) regression.

@@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
 @Experimental
 @Since("1.6.0")
 class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
-  extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging {
+  extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
+  with DefaultParamsWritable with Logging {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("aftSurvReg"))

@@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] {
+
+  @Since("1.6.0")
+  override def load(path: String): AFTSurvivalRegression = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[AFTSurvivalRegression]].

@@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] (
     @Since("1.6.0") val coefficients: Vector,
     @Since("1.6.0") val intercept: Double,
     @Since("1.6.0") val scale: Double)
-  extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams {
+  extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
 
   /** @group setParam */
   @Since("1.6.0")

@@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] (
     copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra)
       .setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter =
+    new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): AFTSurvivalRegressionModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */
+  private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter (
+      instance: AFTSurvivalRegressionModel
+    ) extends MLWriter with Logging {
+
+    private case class Data(coefficients: Vector, intercept: Double, scale: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: coefficients, intercept, scale
+      val data = Data(instance.coefficients, instance.intercept, instance.scale)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[AFTSurvivalRegressionModel].getName
+
+    override def load(path: String): AFTSurvivalRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("coefficients", "intercept", "scale").head()
+      val coefficients = data.getAs[Vector](0)
+      val intercept = data.getDouble(1)
+      val scale = data.getDouble(2)
+      val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
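
The writer/reader pair above persists the model as Params metadata plus a "data" Parquet file. As a minimal usage sketch of the resulting API (assuming an in-scope sqlContext; the training rows and the /tmp path are illustrative, not part of this patch):

import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
import org.apache.spark.mllib.linalg.Vectors

// Illustrative survival data: (label, censor, features); censor == 1.0 means
// the event occurred, 0.0 means the observation is censored.
val training = sqlContext.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val model = new AFTSurvivalRegression().fit(training)

// Round-trip the fitted model through the new writer and reader.
model.write.overwrite().save("/tmp/aft-model") // hypothetical path
val loaded = AFTSurvivalRegressionModel.load("/tmp/aft-model")
assert(loaded.coefficients == model.coefficients)
assert(loaded.intercept == model.intercept && loaded.scale == model.scale)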

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 76 additions & 7 deletions

@@ -17,18 +17,22 @@
 
 package org.apache.spark.ml.regression
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel}
+import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
+import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 
 /**

@@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
 @Since("1.5.0")
 @Experimental
 class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase {
+  extends Estimator[IsotonicRegressionModel]
+  with IsotonicRegressionBase with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("isoReg"))

@@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
   }
 }
 
+@Since("1.6.0")
+object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] {
+
+  @Since("1.6.0")
+  override def load(path: String): IsotonicRegression = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model fitted by IsotonicRegression.

@@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
 class IsotonicRegressionModel private[ml] (
     override val uid: String,
     private val oldModel: MLlibIsotonicRegressionModel)
-  extends Model[IsotonicRegressionModel] with IsotonicRegressionBase {
+  extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {
 
   /** @group setParam */
   @Since("1.5.0")

@@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] (
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter =
+    new IsotonicRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): IsotonicRegressionModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */
+  private[IsotonicRegressionModel] class IsotonicRegressionModelWriter (
+      instance: IsotonicRegressionModel
+    ) extends MLWriter with Logging {
+
+    private case class Data(
+        boundaries: Array[Double],
+        predictions: Array[Double],
+        isotonic: Boolean)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: boundaries, predictions, isotonic
+      val data = Data(
+        instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[IsotonicRegressionModel].getName
+
+    override def load(path: String): IsotonicRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("boundaries", "predictions", "isotonic").head()
+      val boundaries = data.getAs[Seq[Double]](0).toArray
+      val predictions = data.getAs[Seq[Double]](1).toArray
+      val isotonic = data.getBoolean(2)
+      val model = new IsotonicRegressionModel(
+        metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic))
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
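
A parallel usage sketch for isotonic regression (again assuming an in-scope sqlContext; the data and path are illustrative). Since the ml model wraps the mllib model, boundaries and predictions come back exactly as saved:

import org.apache.spark.ml.regression.{IsotonicRegression, IsotonicRegressionModel}

// Illustrative input: a Double label against a single Double feature.
val dataset = sqlContext.createDataFrame(
  Seq((1.0, 1.0), (3.0, 2.0), (2.0, 3.0), (5.0, 4.0), (4.0, 5.0))
).toDF("label", "features")

val model = new IsotonicRegression().fit(dataset)

// Save and reload; boundaries and predictions survive via the "data" file.
model.write.overwrite().save("/tmp/isotonic-model") // hypothetical path
val loaded = IsotonicRegressionModel.load("/tmp/isotonic-model")
assert(loaded.boundaries == model.boundaries)
assert(loaded.predictions == model.predictions)

On disk, both writers produce the same two-part layout: <path>/metadata holds the JSON written by DefaultParamsWriter (class name, uid, and Params), and <path>/data holds the single-partition Parquet file with the model fields listed in each saveImpl.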

mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala

Lines changed: 33 additions & 4 deletions

@@ -21,14 +21,15 @@ import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
 
-class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class AFTSurvivalRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var datasetUnivariate: DataFrame = _
   @transient var datasetMultivariate: DataFrame = _

@@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(prediction ~== model.predict(features) relTol 1E-5)
     }
   }
+
+  test("read/write") {
+    def checkModelData(
+        model: AFTSurvivalRegressionModel,
+        model2: AFTSurvivalRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients === model2.coefficients)
+      assert(model.scale === model2.scale)
+    }
+    val aft = new AFTSurvivalRegression()
+    testEstimatorAndModelReadWrite(aft, datasetMultivariate,
+      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+  }
+}
+
+object AFTSurvivalRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "fitIntercept" -> true,
+    "maxIter" -> 2,
+    "tol" -> 0.01
+  )
 }
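
Both suites lean on DefaultReadWriteTest.testEstimatorAndModelReadWrite, which applies allParamSettings to the estimator, round-trips the estimator and a fitted model through save/load, and hands the original and reloaded models to the suite's checkModelData. Roughly (a simplified sketch of the helper's behavior, not its exact source; the temp path is hypothetical):

val aft = new AFTSurvivalRegression()
AFTSurvivalRegressionSuite.allParamSettings.foreach { case (p, v) =>
  aft.set(aft.getParam(p), v) // exercise every non-default Param
}
val model = aft.fit(datasetMultivariate)
model.write.overwrite().save("/tmp/aft-readwrite")
val loaded = AFTSurvivalRegressionModel.load("/tmp/aft-readwrite")
assert(loaded.uid == model.uid)
checkModelData(model, loaded) // intercept, coefficients, scale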

mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala

Lines changed: 32 additions & 2 deletions

@@ -19,12 +19,14 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
-class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class IsotonicRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
   private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
     sqlContext.createDataFrame(
       labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }

@@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
   }
+
+  test("read/write") {
+    val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
+
+    def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = {
+      assert(model.boundaries === model2.boundaries)
+      assert(model.predictions === model2.predictions)
+      assert(model.isotonic === model2.isotonic)
+    }
+
+    val ir = new IsotonicRegression()
+    testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
+      checkModelData)
+  }
+}
+
+object IsotonicRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "isotonic" -> true,
+    "featureIndex" -> 0
+  )
 }
