ndarray/
tri.rs

1// Copyright 2014-2024 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use core::cmp::min;
10
11use num_traits::Zero;
12
13use crate::{
14    dimension::{is_layout_c, is_layout_f},
15    Array,
16    ArrayBase,
17    Axis,
18    Data,
19    Dimension,
20    Zip,
21};
22
23impl<S, A, D> ArrayBase<S, D>
24where
25    S: Data<Elem = A>,
26    D: Dimension,
27    A: Clone + Zero,
28{
29    /// Upper triangular of an array.
30    ///
31    /// Return a copy of the array with elements below the *k*-th diagonal zeroed.
32    /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
33    /// For 0D and 1D arrays, `triu` will return an unchanged clone.
34    ///
35    /// See also [`ArrayBase::tril`]
36    ///
37    /// ```
38    /// use ndarray::array;
39    ///
40    /// let arr = array![
41    ///     [1, 2, 3],
42    ///     [4, 5, 6],
43    ///     [7, 8, 9]
44    /// ];
45    /// assert_eq!(
46    ///     arr.triu(0),
47    ///     array![
48    ///         [1, 2, 3],
49    ///         [0, 5, 6],
50    ///         [0, 0, 9]
51    ///     ]
52    /// );
53    /// ```
54    pub fn triu(&self, k: isize) -> Array<A, D>
55    {
56        if self.ndim() <= 1 {
57            return self.to_owned();
58        }
59
60        // Performance optimization for F-order arrays.
61        // C-order array check prevents infinite recursion in edge cases like [[1]].
62        // k-size check prevents underflow when k == isize::MIN
63        let n = self.ndim();
64        if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
65            let mut x = self.view();
66            x.swap_axes(n - 2, n - 1);
67            let mut tril = x.tril(-k);
68            tril.swap_axes(n - 2, n - 1);
69
70            return tril;
71        }
72
73        let mut res = Array::zeros(self.raw_dim());
74        let ncols = self.len_of(Axis(n - 1));
75        let nrows = self.len_of(Axis(n - 2));
76        let indices = Array::from_iter(0..nrows);
77        Zip::from(self.rows())
78            .and(res.rows_mut())
79            .and_broadcast(&indices)
80            .for_each(|src, mut dst, row_num| {
81                let mut lower = match k >= 0 {
82                    true => row_num.saturating_add(k as usize),        // Avoid overflow
83                    false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
84                };
85                lower = min(lower, ncols);
86                dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
87            });
88
89        res
90    }
91
92    /// Lower triangular of an array.
93    ///
94    /// Return a copy of the array with elements above the *k*-th diagonal zeroed.
95    /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
96    /// For 0D and 1D arrays, `tril` will return an unchanged clone.
97    ///
98    /// See also [`ArrayBase::triu`]
99    ///
100    /// ```
101    /// use ndarray::array;
102    ///
103    /// let arr = array![
104    ///     [1, 2, 3],
105    ///     [4, 5, 6],
106    ///     [7, 8, 9]
107    /// ];
108    /// assert_eq!(
109    ///     arr.tril(0),
110    ///     array![
111    ///         [1, 0, 0],
112    ///         [4, 5, 0],
113    ///         [7, 8, 9]
114    ///     ]
115    /// );
116    /// ```
117    pub fn tril(&self, k: isize) -> Array<A, D>
118    {
119        if self.ndim() <= 1 {
120            return self.to_owned();
121        }
122
123        // Performance optimization for F-order arrays.
124        // C-order array check prevents infinite recursion in edge cases like [[1]].
125        // k-size check prevents underflow when k == isize::MIN
126        let n = self.ndim();
127        if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
128            let mut x = self.view();
129            x.swap_axes(n - 2, n - 1);
130            let mut tril = x.triu(-k);
131            tril.swap_axes(n - 2, n - 1);
132
133            return tril;
134        }
135
136        let mut res = Array::zeros(self.raw_dim());
137        let ncols = self.len_of(Axis(n - 1));
138        let nrows = self.len_of(Axis(n - 2));
139        let indices = Array::from_iter(0..nrows);
140        Zip::from(self.rows())
141            .and(res.rows_mut())
142            .and_broadcast(&indices)
143            .for_each(|src, mut dst, row_num| {
144                // let row_num = i.into_dimension().last_elem();
145                let mut upper = match k >= 0 {
146                    true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
147                    false => row_num.saturating_sub((k + 1).unsigned_abs()),      // Avoid underflow
148                };
149                upper = min(upper, ncols);
150                dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
151            });
152
153        res
154    }
155}
156
#[cfg(test)]
mod tests
{
    use core::isize;

    use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
    use alloc::vec;

    /// An F-order 2-D input must yield an F-order result from both functions.
    #[test]
    fn test_keep_order()
    {
        let x = Array2::<f64>::ones((3, 3).f());
        let res = x.triu(0);
        assert!(dimension::is_layout_f(&res.dim, &res.strides));

        let res = x.tril(0);
        assert!(dimension::is_layout_f(&res.dim, &res.strides));
    }

    /// 0-D arrays are returned unchanged, whichever layout was requested.
    #[test]
    fn test_0d()
    {
        let x = Array0::<f64>::ones(());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array0::<f64>::ones(().f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    /// 1-D arrays are returned unchanged by both `triu` and `tril`.
    #[test]
    fn test_1d()
    {
        let x = array![1, 2, 3];
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array1::<f64>::ones(3.f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    /// Square 2-D input, once in C order and once in F order.
    #[test]
    fn test_2d()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Upper
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        // Lower
        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);

        // Same logical matrix, stored column-major.
        let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();

        // Upper
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        // Lower
        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    }

    /// A 1x1 matrix is both C- and F-contiguous; make sure no offset recurses
    /// forever and that the single element is kept or zeroed as expected.
    #[test]
    fn test_2d_single()
    {
        let x = array![[1]];

        assert_eq!(x.triu(0), array![[1]]);
        assert_eq!(x.tril(0), array![[1]]);
        assert_eq!(x.triu(1), array![[0]]);
        assert_eq!(x.tril(1), array![[1]]);
        assert_eq!(x.triu(-1), array![[1]]);
        assert_eq!(x.tril(-1), array![[0]]);
    }

    /// For 3-D input the triangle is taken over the last two axes of every
    /// leading index, for both C-order and F-order storage.
    #[test]
    fn test_3d()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        // Upper
        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        // Lower
        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );

        // Same logical array, stored column-major.
        let x = Array3::from_shape_vec(
            (3, 3, 3).f(),
            vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
        )
        .unwrap();

        // Upper
        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        // Lower
        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );
    }

    /// Non-zero offsets: `k = 1` and `k = -1` for both `triu` and `tril`.
    #[test]
    fn test_off_axis()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        let res = x.triu(1);
        assert_eq!(
            res,
            array![
                [[0, 2, 3], [0, 0, 6], [0, 0, 0]],
                [[0, 11, 12], [0, 0, 15], [0, 0, 0]],
                [[0, 20, 21], [0, 0, 24], [0, 0, 0]]
            ]
        );

        let res = x.triu(-1);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [4, 5, 6], [0, 8, 9]],
                [[10, 11, 12], [13, 14, 15], [0, 17, 18]],
                [[19, 20, 21], [22, 23, 24], [0, 26, 27]]
            ]
        );

        let res = x.tril(1);
        assert_eq!(
            res,
            array![
                [[1, 2, 0], [4, 5, 6], [7, 8, 9]],
                [[10, 11, 0], [13, 14, 15], [16, 17, 18]],
                [[19, 20, 0], [22, 23, 24], [25, 26, 27]]
            ]
        );

        let res = x.tril(-1);
        assert_eq!(
            res,
            array![
                [[0, 0, 0], [4, 0, 0], [7, 8, 0]],
                [[0, 0, 0], [13, 0, 0], [16, 17, 0]],
                [[0, 0, 0], [22, 0, 0], [25, 26, 0]]
            ]
        );
    }

    /// Non-square matrices: more columns than rows, then more rows than columns.
    #[test]
    fn test_odd_shape()
    {
        let x = array![[1, 2, 3], [4, 5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);

        let x = array![[1, 2], [3, 4], [5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
    }

    /// Extreme offsets must saturate rather than overflow: the result is
    /// either the whole array or all zeros.
    #[test]
    fn test_odd_k()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let z = Array2::zeros([3, 3]);
        assert_eq!(x.triu(isize::MIN), x);
        assert_eq!(x.tril(isize::MIN), z);
        assert_eq!(x.triu(isize::MAX), z);
        assert_eq!(x.tril(isize::MAX), x);
    }
}