Skip to content

Commit fdfac22

Browse files
NarineK authored and rxin committed
[SPARK-12509][SQL] Fixed error messages for DataFrame correlation and covariance
Currently, when we call corr or cov on a DataFrame with invalid input, we see these same error messages for both corr and cov, even though they mention only covariance: - "Currently cov supports calculating the covariance between two columns" - "Covariance calculation for columns with dataType "[DataType Name]" not supported." I've fixed this issue by passing the function name as an argument. We could also do the input checks separately for each function, but I avoided doing that because it would duplicate code. Thanks! Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com> Closes apache#10458 from NarineK/sparksqlstatsmessages.
1 parent 34de24a commit fdfac22

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging {
2929

3030
/** Calculate the Pearson Correlation Coefficient for the given columns */
3131
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
32-
val counts = collectStatisticalData(df, cols)
32+
val counts = collectStatisticalData(df, cols, "correlation")
3333
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
3434
}
3535

@@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging {
7373
def cov: Double = Ck / (count - 1)
7474
}
7575

76-
private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
77-
require(cols.length == 2, "Currently cov supports calculating the covariance " +
76+
private def collectStatisticalData(df: DataFrame, cols: Seq[String],
77+
functionName: String): CovarianceCounter = {
78+
require(cols.length == 2, s"Currently $functionName calculation is supported " +
7879
"between two columns.")
7980
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
8081
require(data.nonEmpty, s"Couldn't find column with name $name")
81-
require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
82-
s"with dataType ${data.get.dataType} not supported.")
82+
require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " +
83+
s"for columns with dataType ${data.get.dataType} not supported.")
8384
}
8485
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
8586
df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)(
@@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging {
9899
* @return the covariance of the two columns.
99100
*/
100101
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
101-
val counts = collectStatisticalData(df, cols)
102+
val counts = collectStatisticalData(df, cols, "covariance")
102103
counts.cov
103104
}
104105

0 commit comments

Comments
 (0)