@@ -126,8 +126,8 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
126
126
floating beta, floating * y, int incy) noexcept nogil:
127
127
""" y := alpha * op(A).x + beta * y"""
128
128
cdef char ta_ = ta
129
- if order == RowMajor:
130
- ta_ = NoTrans if ta == Trans else Trans
129
+ if order == BLAS_Order. RowMajor:
130
+ ta_ = BLAS_Trans. NoTrans if ta == BLAS_Trans. Trans else BLAS_Trans. Trans
131
131
if floating is float :
132
132
sgemv(& ta_, & n, & m, & alpha, < float * > A, & lda, < float * > x,
133
133
& incx, & beta, y, & incy)
@@ -148,8 +148,10 @@ cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
148
148
cdef:
149
149
int m = A.shape[0 ]
150
150
int n = A.shape[1 ]
151
- BLAS_Order order = ColMajor if A.strides[0 ] == A.itemsize else RowMajor
152
- int lda = m if order == ColMajor else n
151
+ BLAS_Order order = (
152
+ BLAS_Order.ColMajor if A.strides[0 ] == A.itemsize else BLAS_Order.RowMajor
153
+ )
154
+ int lda = m if order == BLAS_Order.ColMajor else n
153
155
154
156
_gemv(order, ta, m, n, alpha, & A[0 , 0 ], lda, & x[0 ], 1 , beta, & y[0 ], 1 )
155
157
@@ -158,7 +160,7 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
158
160
const floating * x, int incx, const floating * y,
159
161
int incy, floating * A, int lda) noexcept nogil:
160
162
""" A := alpha * x.y.T + A"""
161
- if order == RowMajor:
163
+ if order == BLAS_Order. RowMajor:
162
164
if floating is float :
163
165
sger(& n, & m, & alpha, < float * > y, & incy, < float * > x, & incx, A, & lda)
164
166
else :
@@ -175,8 +177,10 @@ cpdef _ger_memview(floating alpha, const floating[::1] x,
175
177
cdef:
176
178
int m = A.shape[0 ]
177
179
int n = A.shape[1 ]
178
- BLAS_Order order = ColMajor if A.strides[0 ] == A.itemsize else RowMajor
179
- int lda = m if order == ColMajor else n
180
+ BLAS_Order order = (
181
+ BLAS_Order.ColMajor if A.strides[0 ] == A.itemsize else BLAS_Order.RowMajor
182
+ )
183
+ int lda = m if order == BLAS_Order.ColMajor else n
180
184
181
185
_ger(order, m, n, alpha, & x[0 ], 1 , & y[0 ], 1 , & A[0 , 0 ], lda)
182
186
@@ -194,7 +198,7 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
194
198
cdef:
195
199
char ta_ = ta
196
200
char tb_ = tb
197
- if order == RowMajor:
201
+ if order == BLAS_Order. RowMajor:
198
202
if floating is float :
199
203
sgemm(& tb_, & ta_, & n, & m, & k, & alpha, < float * > B,
200
204
& ldb, < float * > A, & lda, & beta, C, & ldc)
@@ -214,19 +218,21 @@ cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
214
218
const floating[:, :] A, const floating[:, :] B, floating beta,
215
219
floating[:, :] C):
216
220
cdef:
217
- int m = A.shape[0 ] if ta == NoTrans else A.shape[1 ]
218
- int n = B.shape[1 ] if tb == NoTrans else B.shape[0 ]
219
- int k = A.shape[1 ] if ta == NoTrans else A.shape[0 ]
221
+ int m = A.shape[0 ] if ta == BLAS_Trans. NoTrans else A.shape[1 ]
222
+ int n = B.shape[1 ] if tb == BLAS_Trans. NoTrans else B.shape[0 ]
223
+ int k = A.shape[1 ] if ta == BLAS_Trans. NoTrans else A.shape[0 ]
220
224
int lda, ldb, ldc
221
- BLAS_Order order = ColMajor if A.strides[0 ] == A.itemsize else RowMajor
225
+ BLAS_Order order = (
226
+ BLAS_Order.ColMajor if A.strides[0 ] == A.itemsize else BLAS_Order.RowMajor
227
+ )
222
228
223
- if order == RowMajor:
224
- lda = k if ta == NoTrans else m
225
- ldb = n if tb == NoTrans else k
229
+ if order == BLAS_Order. RowMajor:
230
+ lda = k if ta == BLAS_Trans. NoTrans else m
231
+ ldb = n if tb == BLAS_Trans. NoTrans else k
226
232
ldc = n
227
233
else :
228
- lda = m if ta == NoTrans else k
229
- ldb = k if tb == NoTrans else n
234
+ lda = m if ta == BLAS_Trans. NoTrans else k
235
+ ldb = k if tb == BLAS_Trans. NoTrans else n
230
236
ldc = m
231
237
232
238
_gemm(order, ta, tb, m, n, k, alpha, & A[0 , 0 ],
0 commit comments