Commit 2f30927

jkbradley authored and mengxr committed
[SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib - 1.5 backport
This backports [apache#10161] to Spark 1.5, with the difference that ChiSqSelector does not require modification.

Switched from using the SQLContext constructor to SQLContext.getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib; there were no uses of the constructor in spark.ml.

CC: yhuai mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes apache#10183 from jkbradley/sqlcontext-backport1.5.
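For context, below is a minimal sketch of the pattern this commit applies, assuming the Spark 1.5 APIs involved; the `Data` case class and `ExampleSaveLoad` object are hypothetical stand-ins for the per-model save/load helpers, not code from this commit. `SQLContext.getOrCreate(sc)` returns the JVM's singleton SQLContext, creating one only if none exists, whereas `new SQLContext(sc)` constructed a fresh context on every save/load call, which did not share settings or registered tables with any existing context.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    // Hypothetical stand-in for the per-model case classes serialized to Parquet.
    case class Data(weights: Array[Double], intercept: Double)

    object ExampleSaveLoad {
      def save(sc: SparkContext, path: String, data: Data): Unit = {
        // Reuse the JVM-wide SQLContext instead of constructing a new one per call.
        val sqlContext = SQLContext.getOrCreate(sc)
        import sqlContext.implicits._
        sc.parallelize(Seq(data)).toDF().write.parquet(path)
      }

      def load(sc: SparkContext, path: String): Data = {
        val sqlContext = SQLContext.getOrCreate(sc)
        val df = sqlContext.read.parquet(path)
        val row = df.select("weights", "intercept").head()
        Data(row.getAs[Seq[Double]](0).toArray, row.getDouble(1))
      }
    }

Since getOrCreate only constructs a SQLContext when none is registered for the JVM, repeated save/load calls stop creating throwaway contexts, which is the motivation for the change.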
1 parent 3868ab6 commit 2f30927

File tree

12 files changed: +27 -27 lines changed


mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 3 additions & 3 deletions
@@ -1149,7 +1149,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
     // We use DataFrames for serialization of IndexedRows to Python,
     // so return a DataFrame.
-    val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext)
     sqlContext.createDataFrame(indexedRowMatrix.rows)
   }

@@ -1159,7 +1159,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
     // We use DataFrames for serialization of MatrixEntry entries to
     // Python, so return a DataFrame.
-    val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext)
     sqlContext.createDataFrame(coordinateMatrix.entries)
   }

@@ -1169,7 +1169,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
     // We use DataFrames for serialization of sub-matrix blocks to
     // Python, so return a DataFrame.
-    val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext)
     sqlContext.createDataFrame(blockMatrix.blocks)
   }
 }

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 4 additions & 4 deletions
@@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         modelType: String)

     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.
@@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {

     @Since("1.3.0")
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       // Load Parquet data.
       val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
@@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         theta: Array[Array[Double]])

     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.
@@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     }

     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       // Load Parquet data.
       val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
         weights: Vector,
         intercept: Double,
         threshold: Option[Double]): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.
@@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
       val datapath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataRDD = sqlContext.read.parquet(datapath)
       val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
       assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 2 additions & 2 deletions
@@ -149,7 +149,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
         weights: Array[Double],
         gaussians: Array[MultivariateGaussian]): Unit = {

-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.
@@ -166,7 +166,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {

     def load(sc: SparkContext, path: String): GaussianMixtureModel = {
       val dataPath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
       val dataArray = dataFrame.select("weight", "mu", "sigma").collect()

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"

     def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
@@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] {

     def load(sc: SparkContext, path: String): KMeansModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)

mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode

     @Since("1.4.0")
     def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       val metadata = compact(render(
@@ -87,7 +87,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
     @Since("1.4.0")
     def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)

       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 2 additions & 2 deletions
@@ -588,7 +588,7 @@ object Word2VecModel extends Loader[Word2VecModel] {

   def load(sc: SparkContext, path: String): Word2VecModel = {
     val dataPath = Loader.dataPath(path)
-    val sqlContext = new SQLContext(sc)
+    val sqlContext = SQLContext.getOrCreate(sc)
     val dataFrame = sqlContext.read.parquet(dataPath)

     val dataArray = dataFrame.select("word", "vector").collect()
@@ -602,7 +602,7 @@ object Word2VecModel extends Loader[Word2VecModel] {

   def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {

-    val sqlContext = new SQLContext(sc)
+    val sqlContext = SQLContext.getOrCreate(sc)
     import sqlContext.implicits._

     val vectorSize = model.values.head.size

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 2 additions & 2 deletions
@@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
    */
   def save(model: MatrixFactorizationModel, path: String): Unit = {
     val sc = model.userFeatures.sparkContext
-    val sqlContext = new SQLContext(sc)
+    val sqlContext = SQLContext.getOrCreate(sc)
     import sqlContext.implicits._
     val metadata = compact(render(
       ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
@@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

   def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
     implicit val formats = DefaultFormats
-    val sqlContext = new SQLContext(sc)
+    val sqlContext = SQLContext.getOrCreate(sc)
     val (className, formatVersion, metadata) = loadMetadata(sc, path)
     assert(className == thisClassName)
     assert(formatVersion == thisFormatVersion)

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
         boundaries: Array[Double],
         predictions: Array[Double],
         isotonic: Boolean): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)

       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -201,7 +201,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
     }

     def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataRDD = sqlContext.read.parquet(dataPath(path))

       checkSchema[Data](dataRDD.schema)

mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
         modelClass: String,
         weights: Vector,
         intercept: Double): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.
@@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
       val datapath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataRDD = sqlContext.read.parquet(datapath)
       val dataArray = dataRDD.select("weights", "intercept").take(1)
       assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
