Skip to content

Commit b20a9ab

Browse files
rotationsymmetrymengxr
authored andcommitted
[SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0 and beta!=1
Fix BLAS.gemm to update matrix C when alpha==0 and beta!=1 Also include unit tests to verify the fix. mengxr brkyvz Author: Meihua Wu <meihuawu@umich.edu> Closes apache#7503 from rotationsymmetry/fix_BLAS_gemm and squashes the following commits: fce199c [Meihua Wu] Fix BLAS.gemm to update C when alpha==0 and beta!=1 (cherry picked from commit ff3c72d) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent c8b17da commit b20a9ab

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
303303
C: DenseMatrix): Unit = {
304304
require(!C.isTransposed,
305305
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
306-
if (alpha == 0.0) {
307-
logDebug("gemm: alpha is equal to 0. Returning C.")
306+
if (alpha == 0.0 && beta == 1.0) {
307+
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
308308
} else {
309309
A match {
310310
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
200200
val C10 = C1.copy
201201
val C11 = C1.copy
202202
val C12 = C1.copy
203+
val C13 = C1.copy
204+
val C14 = C1.copy
205+
val C15 = C1.copy
206+
val C16 = C1.copy
203207
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
204208
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
209+
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
210+
val expected5 = C1.copy
205211

206212
gemm(1.0, dA, B, 2.0, C1)
207213
gemm(1.0, sA, B, 2.0, C2)
@@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
248254
assert(C10 ~== expected2 absTol 1e-15)
249255
assert(C11 ~== expected3 absTol 1e-15)
250256
assert(C12 ~== expected3 absTol 1e-15)
257+
258+
gemm(0, dA, B, 5, C13)
259+
gemm(0, sA, B, 5, C14)
260+
gemm(0, dA, B, 1, C15)
261+
gemm(0, sA, B, 1, C16)
262+
assert(C13 ~== expected4 absTol 1e-15)
263+
assert(C14 ~== expected4 absTol 1e-15)
264+
assert(C15 ~== expected5 absTol 1e-15)
265+
assert(C16 ~== expected5 absTol 1e-15)
266+
251267
}
252268

253269
test("gemv") {

0 commit comments

Comments
 (0)