Skip to content

Commit 29836e2

Browse files
srowenmengxr
authored andcommitted
[SPARK-10353] [MLLIB] (1.3 backport) BLAS gemm not scaling when beta = 0.0 for some subset of matrix multiplications
Apply fixes for alpha, beta parameter handling in gemm/gemv from apache#8525 to branch 1.3 CC mengxr brkyvz Author: Sean Owen <sowen@cloudera.com> Closes apache#8572 from srowen/SPARK-10353.2.
1 parent a58c1af commit 29836e2

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging {
305305
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
306306
if (alpha == 0.0 && beta == 1.0) {
307307
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
308+
} else if (alpha == 0.0) {
309+
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
308310
} else {
309311
A match {
310312
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
@@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging {
408410
}
409411
}
410412
} else {
411-
// Scale matrix first if `beta` is not equal to 0.0
412-
if (beta != 0.0) {
413+
// Scale matrix first if `beta` is not equal to 1.0
414+
if (beta != 1.0) {
413415
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
414416
}
415417
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
@@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging {
470472
s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
471473
require(A.numRows == y.size,
472474
s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}")
473-
if (alpha == 0.0) {
474-
logDebug("gemv: alpha is equal to 0. Returning y.")
475+
if (alpha == 0.0 && beta == 1.0) {
476+
logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.")
477+
} else if (alpha == 0.0) {
478+
scal(beta, y)
475479
} else {
476480
A match {
477481
case sparse: SparseMatrix =>
@@ -534,8 +538,8 @@ private[spark] object BLAS extends Serializable with Logging {
534538
rowCounter += 1
535539
}
536540
} else {
537-
// Scale vector first if `beta` is not equal to 0.0
538-
if (beta != 0.0) {
541+
// Scale vector first if `beta` is not equal to 1.0
542+
if (beta != 1.0) {
539543
scal(beta, y)
540544
}
541545
// Perform matrix-vector multiplication and add to y

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite {
204204
val C14 = C1.copy
205205
val C15 = C1.copy
206206
val C16 = C1.copy
207+
val C17 = C1.copy
207208
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
208209
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
209210
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
@@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite {
217218
assert(C2 ~== expected2 absTol 1e-15)
218219
assert(C3 ~== expected3 absTol 1e-15)
219220
assert(C4 ~== expected3 absTol 1e-15)
221+
gemm(1.0, dA, B, 0.0, C17)
222+
assert(C17 ~== expected absTol 1e-15)
223+
gemm(1.0, sA, B, 0.0, C17)
224+
assert(C17 ~== expected absTol 1e-15)
220225

221226
withClue("columns of A don't match the rows of B") {
222227
intercept[Exception] {

0 commit comments

Comments
 (0)