ndarray/linalg/impl_linalg.rs

// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;

#[cfg(feature = "blas")]
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::numeric_util;

use crate::{LinalgScalar, Zip};

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use std::any::TypeId;
use std::mem::MaybeUninit;

use num_complex::Complex;
use num_complex::{Complex32 as c32, Complex64 as c64};

#[cfg(feature = "blas")]
use libc::c_int;

#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};

/// Minimum length of a vector before we use BLAS for `dot`
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Minimum side length of a matrix before we use BLAS for `gemm`
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int; // blas index type

impl<A, S> ArrayBase<S, Ix1>
where S: Data<Elem = A>
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
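    ///
    /// A small usage sketch for the one-dimensional case (values are
    /// illustrative):
    ///
    /// ```
    /// use ndarray::arr1;
    ///
    /// let a = arr1(&[1., 2., 3.]);
    /// let b = arr1(&[4., 5., 6.]);
    /// // 1·4 + 2·5 + 3·6 = 32
    /// assert_eq!(a.dot(&b), 32.);
    /// ```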
    #[track_caller]
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }

    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        let mut sum = A::zero();
        for i in 0..self.len() {
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    #[cfg(not(feature = "blas"))]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        self.dot_generic(rhs)
    }

    #[cfg(feature = "blas")]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        // Use only if the vector is large enough to be worth it
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}

/// Return a pointer to the starting element in BLAS's view.
///
/// BLAS wants a pointer to the element with lowest address,
/// which agrees with our pointer for non-negative strides, but
/// is at the opposite end for negative strides.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index)
{
    // [x x x x]
    //        ^--ptr
    //        stride = -1
    //  ^--blas_ptr = ptr + (len - 1) * stride
    if stride >= 0 || len == 0 {
        (ptr, len as blas_index, stride as blas_index)
    } else {
        let ptr = ptr.offset((len - 1) as isize * stride);
        (ptr, len as blas_index, stride as blas_index)
    }
}

/// Matrix Multiplication
///
/// For two-dimensional arrays, the dot method computes the matrix
/// multiplication.
pub trait Dot<Rhs>
{
    /// The result of the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    {
        self.dot_impl(rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
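    ///
    /// A small usage sketch (values are illustrative):
    ///
    /// ```
    /// use ndarray::{arr1, arr2};
    ///
    /// let a = arr1(&[1., 2.]);
    /// let m = arr2(&[[1., 2., 3.],
    ///                [4., 5., 6.]]);
    /// // `a` is treated as a row vector, so `a.dot(&m)` has shape 3
    /// assert_eq!(a.dot(&m), arr1(&[9., 12., 15.]));
    /// ```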
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1>
    {
        rhs.t().dot(self)
    }
}

impl<A, S> ArrayBase<S, Ix2>
where S: Data<Elem = A>
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is two-dimensional, the array shapes must agree in the way that
    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
    ///
    /// Return a result array with shape *M* × *K*.
    ///
    /// **Panics** if shapes are incompatible or the number of elements in the
    /// result would overflow `isize`.
    ///
    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
    /// when memory layout allows. The default matrixmultiply backend
    /// is otherwise used for `f32, f64` for all memory layouts.
    ///
    /// ```
    /// use ndarray::arr2;
    ///
    /// let a = arr2(&[[1., 2.],
    ///                [0., 1.]]);
    /// let b = arr2(&[[1., 2.],
    ///                [2., 3.]]);
    ///
    /// assert!(
    ///     a.dot(&b) == arr2(&[[5., 8.],
    ///                         [2., 3.]])
    /// );
    /// ```
    #[track_caller]
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where Self: Dot<Rhs>
    {
        Dot::dot(self, rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array2<A>;
    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A>
    {
        let a = self.view();
        let b = b.view();
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            dot_shape_error(m, k, k2, n);
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }
        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
        c
    }
}

/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> !
{
    match m.checked_mul(n) {
        Some(len) if len <= isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}

#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> !
{
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}

/// Perform the matrix multiplication of the rectangular array `self` and
/// column vector `rhs`.
///
/// The array shapes must agree in the way that
/// if `self` is *M* × *N*, then `rhs` is *N*.
///
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
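///
/// A small usage sketch (values are illustrative):
///
/// ```
/// use ndarray::{arr1, arr2};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let x = arr1(&[10., 20.]);
/// // [1·10 + 2·20, 3·10 + 4·20] = [50, 110]
/// assert_eq!(a.dot(&x), arr1(&[50., 110.]));
/// ```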
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;
    #[track_caller]
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1>
    {
        let ((m, a), n) = (self.dim(), rhs.dim());
        if a != n {
            dot_shape_error(m, a, n, 1);
        }

        // Avoid initializing the memory in vec -- set it during iteration
        unsafe {
            let mut c = Array1::uninit(m);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
            c.assume_init()
        }
    }
}

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Perform the operation `self += alpha * rhs` efficiently, where
    /// `alpha` is a scalar and `rhs` is another array. This operation is
    /// also known as `axpy` in BLAS.
    ///
    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
    ///
    /// **Panics** if broadcasting isn’t possible.
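    ///
    /// A small usage sketch (values are illustrative):
    ///
    /// ```
    /// use ndarray::arr1;
    ///
    /// let mut y = arr1(&[1., 2., 3.]);
    /// let x = arr1(&[1., 1., 1.]);
    /// // y ← y + 10·x
    /// y.scaled_add(10., &x);
    /// assert_eq!(y, arr1(&[11., 12., 13.]));
    /// ```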
    #[track_caller]
    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where
        S: DataMut,
        S2: Data<Elem = A>,
        A: LinalgScalar,
        E: Dimension,
    {
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }
}

// mat_mul_impl uses ArrayView arguments to send all array kinds into
// the same instantiated implementation.
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;

#[cfg(feature = "blas")]
fn mat_mul_impl<A>(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>)
where A: LinalgScalar
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    debug_assert_eq!(k, k2);
    if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
        && (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
    {
        // Compute A B -> C
        // We require for BLAS compatibility that:
        // A, B, C are contiguous (stride=1) in their fastest dimension,
        // but they can be either row major/"c" or col major/"f".
        //
        // The "normal case" is CblasRowMajor for cblas.
        // Select CblasRowMajor / CblasColMajor to fit C's memory order.
        //
        // Apply transpose to A, B as needed if they differ from the row major case.
        // If C is CblasColMajor then transpose both A, B (again!)

        if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
            (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
        {
            let cblas_layout = c_layout.to_cblas_layout();
            let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
            let lda = blas_stride(&a, a_layout);

            let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
            let ldb = blas_stride(&b, b_layout);

            let ldc = blas_stride(&c, c_layout);

            macro_rules! gemm_scalar_cast {
                (f32, $var:ident) => {
                    cast_as(&$var)
                };
                (f64, $var:ident) => {
                    cast_as(&$var)
                };
                (c32, $var:ident) => {
                    &$var as *const A as *const _
                };
                (c64, $var:ident) => {
                    &$var as *const A as *const _
                };
            }

            macro_rules! gemm {
                ($ty:tt, $gemm:ident) => {
                    if same_type::<A, $ty>() {
                        // gemm is C ← αA^Op B^Op + βC
                        // Where Op is notrans/trans/conjtrans
                        unsafe {
                            blas_sys::$gemm(
                                cblas_layout,
                                a_trans,
                                b_trans,
                                m as blas_index,                 // m, rows of Op(a)
                                n as blas_index,                 // n, cols of Op(b)
                                k as blas_index,                 // k, cols of Op(a)
                                gemm_scalar_cast!($ty, alpha),   // alpha
                                a.ptr.as_ptr() as *const _,      // a
                                lda,                             // lda
                                b.ptr.as_ptr() as *const _,      // b
                                ldb,                             // ldb
                                gemm_scalar_cast!($ty, beta),    // beta
                                c.ptr.as_ptr() as *mut _,        // c
                                ldc,                             // ldc
                            );
                        }
                        return;
                    }
                };
            }

            gemm!(f32, cblas_sgemm);
            gemm!(f64, cblas_dgemm);
            gemm!(c32, cblas_cgemm);
            gemm!(c64, cblas_zgemm);

            unreachable!() // we checked above that A is one of f32, f64, c32, c64
        }
    }
    mat_mul_general(alpha, a, b, beta, c)
}

/// C ← α A B + β C
fn mat_mul_general<A>(
    alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>,
) where A: LinalgScalar
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c32>() {
        unsafe {
            matrixmultiply::cgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, c64>() {
        unsafe {
            matrixmultiply::zgemm(
                matrixmultiply::CGemmOption::Standard,
                matrixmultiply::CGemmOption::Standard,
                m,
                k,
                n,
                complex_array(cast_as(&alpha)),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                complex_array(cast_as(&beta)),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // It's a no-op if `c` has zero length.
        if c.is_empty() {
            return;
        }

        // if beta is zero, fill `c` first: its memory may be uninitialized,
        // and the loop below reads from every element
        if beta.is_zero() {
            c.fill(beta);
        }

        // straightforward (i, j) loop over the output, accumulating the dot
        // product of row i of `lhs` with column j of `rhs`
        let mut i = 0;
        let mut j = 0;
        loop {
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt =
                    *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j)));
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}

/// General matrix-matrix multiplication.
///
/// Compute C ← α A B + β C
///
/// The array shapes must agree in the way that
/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
/// layout allows.  The default matrixmultiply backend is otherwise used for
/// `f32, f64` for all memory layouts.
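///
/// A small usage sketch (values are illustrative; assuming the
/// `ndarray::linalg` re-export of this function):
///
/// ```
/// use ndarray::arr2;
/// use ndarray::linalg::general_mat_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[1., 0.],
///                [0., 1.]]);
/// let mut c = arr2(&[[1., 1.],
///                    [1., 1.]]);
/// // c ← 2·a·b + 1·c
/// general_mat_mul(2., &a, &b, 1., &mut c);
/// assert_eq!(c, arr2(&[[3., 5.],
///                      [7., 9.]]));
/// ```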
#[track_caller]
pub fn general_mat_mul<A, S1, S2, S3>(
    alpha: A, a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>, beta: A, c: &mut ArrayBase<S3, Ix2>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    let (m2, n2) = c.dim();
    if k != k2 || m != m2 || n != n2 {
        general_dot_shape_error(m, k, k2, n, m2, n2);
    } else {
        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
    }
}

/// General matrix-vector multiplication.
///
/// Compute y ← α A x + β y
///
/// where A is an *M* × *N* matrix, x is an *N*-element column vector, and
/// y an *M*-element column vector (one dimensional arrays).
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
/// layout allows.
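///
/// A small usage sketch (values are illustrative; assuming the
/// `ndarray::linalg` re-export of this function):
///
/// ```
/// use ndarray::{arr1, arr2};
/// use ndarray::linalg::general_mat_vec_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let x = arr1(&[1., 1.]);
/// let mut y = arr1(&[1., 1.]);
/// // y ← 1·a·x + 0·y
/// general_mat_vec_mul(1., &a, &x, 0., &mut y);
/// assert_eq!(y, arr1(&[3., 7.]));
/// ```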
#[track_caller]
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A, S1, S2, S3>(
    alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: &mut ArrayBase<S3, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
}

/// General matrix-vector multiplication
///
/// Use a raw view for the destination vector, so that it can be uninitialized.
///
/// ## Safety
///
/// The caller must ensure that the raw view is valid for writing.
/// The destination may be uninitialized iff `beta` is zero.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
    alpha: A, a: &ArrayBase<S1, Ix2>, x: &ArrayBase<S2, Ix1>, beta: A, y: RawArrayViewMut<A, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if same_type::<A, $ty>() {
                    if let Some(layout) = get_blas_compatible_layout(&a) {
                        if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
                            // Determine stride between rows or columns. Note that the stride is
                            // adjusted to at least `k` or `m` to handle the case of a matrix with a
                            // trivial (length 1) dimension, since the stride for the trivial dimension
                            // may be arbitrary.
                            let a_trans = CblasNoTrans;

                            let a_stride = blas_stride(&a, layout);
                            let cblas_layout = layout.to_cblas_layout();

                            // BLAS needs pointers to the elements with the lowest
                            // address in memory for x and y
                            let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
                            let x_ptr = x.ptr.as_ptr().sub(x_offset);
                            let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
                            let y_ptr = y.ptr.as_ptr().sub(y_offset);

                            let x_stride = x.strides()[0] as blas_index;
                            let y_stride = y.strides()[0] as blas_index;

                            blas_sys::$gemv(
                                cblas_layout,
                                a_trans,
                                m as blas_index,            // m, rows of Op(a)
                                k as blas_index,            // n, cols of Op(a)
                                cast_as(&alpha),            // alpha
                                a.ptr.as_ptr() as *const _, // a
                                a_stride,                   // lda
                                x_ptr as *const _,          // x
                                x_stride,
                                cast_as(&beta),             // beta
                                y_ptr as *mut _,            // y
                                y_stride,
                            );
                            return;
                        }
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        /* general */

        if beta.is_zero() {
            // when beta is zero, y may be uninitialized
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}

/// Kronecker product of 2D matrices.
///
/// The Kronecker product of an L×N matrix A and an M×R matrix B is an
/// (L*M)×(N*R) matrix K formed by the block multiplication A_ij * B.
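///
/// A small usage sketch (values are illustrative; assuming the
/// `ndarray::linalg` re-export of this function):
///
/// ```
/// use ndarray::arr2;
/// use ndarray::linalg::kron;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[0., 1.],
///                [1., 0.]]);
/// // each 2×2 block (i, j) of the result is a[(i, j)] * b
/// assert_eq!(kron(&a, &b), arr2(&[[0., 1., 0., 2.],
///                                 [1., 0., 2., 0.],
///                                 [0., 3., 0., 4.],
///                                 [3., 0., 4., 0.]]));
/// ```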
pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let dimar = a.shape()[0];
    let dimac = a.shape()[1];
    let dimbr = b.shape()[0];
    let dimbc = b.shape()[1];
    let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
        dimar
            .checked_mul(dimbr)
            .expect("Dimensions of Kronecker product output array overflow usize."),
        dimac
            .checked_mul(dimbc)
            .expect("Dimensions of Kronecker product output array overflow usize."),
    ));
    Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
        .and(a)
        .for_each(|out, &a| {
            Zip::from(out).and(b).for_each(|out, &b| {
                *out = MaybeUninit::new(a * b);
            })
        });
    unsafe { out.assume_init() }
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool
{
    TypeId::of::<A>() == TypeId::of::<B>()
}

// Read pointer to type `A` as type `B`.
//
// **Panics** if `A` and `B` are not the same type
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B
{
    assert!(same_type::<A, B>(), "expect type {} and {} to match",
            std::any::type_name::<A>(), std::any::type_name::<B>());
    unsafe { ::std::ptr::read(a as *const _ as *const B) }
}

/// Return the complex number in the form of an array [re, im]
#[inline]
fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2]
{
    [z.re, z.im]
}

#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    if a.len() > blas_index::MAX as usize {
        return false;
    }
    let stride = a.strides()[0];
    if stride == 0 || stride > blas_index::MAX as isize || stride < blas_index::MIN as isize {
        return false;
    }
    true
}

#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum BlasOrder
{
    C,
    F,
}

#[cfg(feature = "blas")]
impl BlasOrder
{
    fn transpose(self) -> Self
    {
        match self {
            Self::C => Self::F,
            Self::F => Self::C,
        }
    }

    #[inline]
    /// Axis of leading stride (opposite of contiguous axis)
    fn get_blas_lead_axis(self) -> usize
    {
        match self {
            Self::C => 0,
            Self::F => 1,
        }
    }

    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        match self {
            Self::C => CBLAS_LAYOUT::CblasRowMajor,
            Self::F => CBLAS_LAYOUT::CblasColMajor,
        }
    }

    /// When using cblas_sgemm (etc) with C matrix using `for_layout`,
    /// how should this `self` matrix be transposed
    fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
    {
        let effective_order = match for_layout {
            CBLAS_LAYOUT::CblasRowMajor => self,
            CBLAS_LAYOUT::CblasColMajor => self.transpose(),
        };

        match effective_order {
            Self::C => CblasNoTrans,
            Self::F => CblasTrans,
        }
    }
}

#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
        BlasOrder::C => (s1, s0, m, n),
        BlasOrder::F => (s0, s1, n, m),
    };

    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }

    if s0 < 1 || s1 < 1 {
        return false;
    }

    if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
        || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
    {
        return false;
    }

    // the leading stride must be >= the other dimension (no broadcasting/aliasing)
    if inner_dim > 1 && (outer_stride as usize) < outer_dim {
        return false;
    }

    if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
        return false;
    }

    true
}

/// Get BLAS compatible layout if any (C or F, preferring the former)
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
where S: Data
{
    if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) {
        Some(BlasOrder::C)
    } else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) {
        Some(BlasOrder::F)
    } else {
        None
    }
}

/// `a` should be blas compatible.
///
/// Return the leading stride (lda, ldb, ldc) of the array
#[cfg(feature = "blas")]
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
where S: Data
{
    let axis = order.get_blas_lead_axis();
    let other_axis = 1 - axis;
    let len_this = a.shape()[axis];
    let len_other = a.shape()[other_axis];
    let stride = a.strides()[axis];

    // if current axis has length == 1, then stride does not matter for ndarray
    // but for BLAS we need a stride that makes sense, i.e. it's >= the other axis

    // cast: a should already be blas compatible
    (if len_this <= 1 {
        Ord::max(stride, len_other as isize)
    } else {
        stride
    }) as blas_index
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
}

#[cfg(test)]
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
{
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix()
    {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix()
    {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix()
    {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_skip_rows_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![..;2, ..]);
        assert!(blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_row_major_2d_skip_columns_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        let mv = m.slice(s![.., ..;2]);
        assert!(!blas_row_major_2d::<f32, _>(&mv));
        assert!(!blas_column_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_columns_ok()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![.., ..;2]);
        assert!(blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_col_major_2d_skip_rows_fail()
    {
        let m: Array2<f32> = Array2::zeros((5, 5).f());
        let mv = m.slice(s![..;2, ..]);
        assert!(!blas_column_major_2d::<f32, _>(&mv));
        assert!(!blas_row_major_2d::<f32, _>(&mv));
    }

    #[test]
    fn blas_too_short_stride()
    {
        // The leading stride must be at least as long as the other dimension.
        // For example, in a 5 × 5 matrix, the leading stride must be >= 5 for BLAS.

        const N: usize = 5;
        const MAXSTRIDE: usize = N + 2;
        let mut data = [0; MAXSTRIDE * N];
        let mut iter = 0..data.len();
        data.fill_with(|| iter.next().unwrap());

        for stride in 1..=MAXSTRIDE {
            let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
            eprintln!("{:?}", m);

            if stride < N {
                assert_eq!(get_blas_compatible_layout(&m), None);
            } else {
                assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
            }
        }
    }
}