Skip to content

Commit 3d64376

Browse files
authored
MNT Use BLAS_Order.ColMajor sklearn/utils/_cython_blas.pyx (scikit-learn#31263)
1 parent b98dc79 commit 3d64376

File tree

3 files changed

+30
-17
lines changed

3 files changed

+30
-17
lines changed

azure-pipelines.yml

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ jobs:
8888
DISTRIB: 'conda-free-threaded'
8989
LOCK_FILE: './build_tools/azure/pylatest_free_threaded_linux-64_conda.lock'
9090
COVERAGE: 'false'
91+
SKLEARN_FAULTHANDLER_TIMEOUT: '1800' # 30 * 60 seconds
9192

9293
# Will run all the time regardless of linting outcome.
9394
- template: build_tools/azure/posix.yml

sklearn/conftest.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import builtins
5+
import faulthandler
56
import platform
67
import sys
78
from contextlib import suppress
@@ -341,6 +342,11 @@ def pytest_configure(config):
341342
for line in get_pytest_filterwarning_lines():
342343
config.addinivalue_line("filterwarnings", line)
343344

345+
faulthandler_timeout = int(environ.get("SKLEARN_FAULTHANDLER_TIMEOUT", "0"))
346+
if faulthandler_timeout > 0:
347+
faulthandler.enable()
348+
faulthandler.dump_traceback_later(faulthandler_timeout, exit=True)
349+
344350

345351
@pytest.fixture
346352
def hide_available_pandas(monkeypatch):

sklearn/utils/_cython_blas.pyx

+23-17
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
126126
floating beta, floating *y, int incy) noexcept nogil:
127127
"""y := alpha * op(A).x + beta * y"""
128128
cdef char ta_ = ta
129-
if order == RowMajor:
130-
ta_ = NoTrans if ta == Trans else Trans
129+
if order == BLAS_Order.RowMajor:
130+
ta_ = BLAS_Trans.NoTrans if ta == BLAS_Trans.Trans else BLAS_Trans.Trans
131131
if floating is float:
132132
sgemv(&ta_, &n, &m, &alpha, <float *> A, &lda, <float *> x,
133133
&incx, &beta, y, &incy)
@@ -148,8 +148,10 @@ cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
148148
cdef:
149149
int m = A.shape[0]
150150
int n = A.shape[1]
151-
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
152-
int lda = m if order == ColMajor else n
151+
BLAS_Order order = (
152+
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
153+
)
154+
int lda = m if order == BLAS_Order.ColMajor else n
153155

154156
_gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
155157

@@ -158,7 +160,7 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
158160
const floating *x, int incx, const floating *y,
159161
int incy, floating *A, int lda) noexcept nogil:
160162
"""A := alpha * x.y.T + A"""
161-
if order == RowMajor:
163+
if order == BLAS_Order.RowMajor:
162164
if floating is float:
163165
sger(&n, &m, &alpha, <float *> y, &incy, <float *> x, &incx, A, &lda)
164166
else:
@@ -175,8 +177,10 @@ cpdef _ger_memview(floating alpha, const floating[::1] x,
175177
cdef:
176178
int m = A.shape[0]
177179
int n = A.shape[1]
178-
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
179-
int lda = m if order == ColMajor else n
180+
BLAS_Order order = (
181+
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
182+
)
183+
int lda = m if order == BLAS_Order.ColMajor else n
180184

181185
_ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
182186

@@ -194,7 +198,7 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
194198
cdef:
195199
char ta_ = ta
196200
char tb_ = tb
197-
if order == RowMajor:
201+
if order == BLAS_Order.RowMajor:
198202
if floating is float:
199203
sgemm(&tb_, &ta_, &n, &m, &k, &alpha, <float*>B,
200204
&ldb, <float*>A, &lda, &beta, C, &ldc)
@@ -214,19 +218,21 @@ cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
214218
const floating[:, :] A, const floating[:, :] B, floating beta,
215219
floating[:, :] C):
216220
cdef:
217-
int m = A.shape[0] if ta == NoTrans else A.shape[1]
218-
int n = B.shape[1] if tb == NoTrans else B.shape[0]
219-
int k = A.shape[1] if ta == NoTrans else A.shape[0]
221+
int m = A.shape[0] if ta == BLAS_Trans.NoTrans else A.shape[1]
222+
int n = B.shape[1] if tb == BLAS_Trans.NoTrans else B.shape[0]
223+
int k = A.shape[1] if ta == BLAS_Trans.NoTrans else A.shape[0]
220224
int lda, ldb, ldc
221-
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
225+
BLAS_Order order = (
226+
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
227+
)
222228

223-
if order == RowMajor:
224-
lda = k if ta == NoTrans else m
225-
ldb = n if tb == NoTrans else k
229+
if order == BLAS_Order.RowMajor:
230+
lda = k if ta == BLAS_Trans.NoTrans else m
231+
ldb = n if tb == BLAS_Trans.NoTrans else k
226232
ldc = n
227233
else:
228-
lda = m if ta == NoTrans else k
229-
ldb = k if tb == NoTrans else n
234+
lda = m if ta == BLAS_Trans.NoTrans else k
235+
ldb = k if tb == BLAS_Trans.NoTrans else n
230236
ldc = m
231237

232238
_gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],

0 commit comments

Comments
 (0)