Skip to content

Commit 0df5304

Browse files
lee19mengxr
authored andcommitted
[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
I'm sorry that I made apache#6949 closed by mistake. I pushed codes again. And, I added a test code. > There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()` It should have been `U.numCols() = k = svd.U.numCols()` > ``` self = U * sigma * V.transpose (m x n) = (m x n) * (k x k) * (k x n) //ASIS --> (m x n) = (m x k) * (k x k) * (k x n) //TOBE ``` Author: lee19 <lee19@live.co.kr> Closes apache#6953 from lee19/MLlibBugfix and squashes the following commits: c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden. 4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error. c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib] 8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k (cherry picked from commit e725262) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 24c2c58 commit 0df5304

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class IndexedRowMatrix(
108108
val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
109109
IndexedRow(i, v)
110110
}
111-
new IndexedRowMatrix(indexedRows, nRows, nCols)
111+
new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
112112
} else {
113113
null
114114
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
113113
assert(closeToZero(U * brzDiag(s) * V.t - localA))
114114
}
115115

116+
test("validate matrix sizes of svd") {
117+
val k = 2
118+
val A = new IndexedRowMatrix(indexedRows)
119+
val svd = A.computeSVD(k, computeU = true)
120+
assert(svd.U.numRows() === m)
121+
assert(svd.U.numCols() === k)
122+
assert(svd.s.size === k)
123+
assert(svd.V.numRows === n)
124+
assert(svd.V.numCols === k)
125+
}
126+
116127
def closeToZero(G: BDM[Double]): Boolean = {
117128
G.valuesIterator.map(math.abs).sum < 1e-6
118129
}

0 commit comments

Comments
 (0)