1use core::cmp::min;
10
11use num_traits::Zero;
12
13use crate::{
14 dimension::{is_layout_c, is_layout_f},
15 Array,
16 ArrayBase,
17 Axis,
18 Data,
19 Dimension,
20 Zip,
21};
22
23impl<S, A, D> ArrayBase<S, D>
24where
25 S: Data<Elem = A>,
26 D: Dimension,
27 A: Clone + Zero,
28{
29 pub fn triu(&self, k: isize) -> Array<A, D>
55 {
56 if self.ndim() <= 1 {
57 return self.to_owned();
58 }
59
60 let n = self.ndim();
64 if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
65 let mut x = self.view();
66 x.swap_axes(n - 2, n - 1);
67 let mut tril = x.tril(-k);
68 tril.swap_axes(n - 2, n - 1);
69
70 return tril;
71 }
72
73 let mut res = Array::zeros(self.raw_dim());
74 let ncols = self.len_of(Axis(n - 1));
75 let nrows = self.len_of(Axis(n - 2));
76 let indices = Array::from_iter(0..nrows);
77 Zip::from(self.rows())
78 .and(res.rows_mut())
79 .and_broadcast(&indices)
80 .for_each(|src, mut dst, row_num| {
81 let mut lower = match k >= 0 {
82 true => row_num.saturating_add(k as usize), false => row_num.saturating_sub(k.unsigned_abs()), };
85 lower = min(lower, ncols);
86 dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
87 });
88
89 res
90 }
91
92 pub fn tril(&self, k: isize) -> Array<A, D>
118 {
119 if self.ndim() <= 1 {
120 return self.to_owned();
121 }
122
123 let n = self.ndim();
127 if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
128 let mut x = self.view();
129 x.swap_axes(n - 2, n - 1);
130 let mut tril = x.triu(-k);
131 tril.swap_axes(n - 2, n - 1);
132
133 return tril;
134 }
135
136 let mut res = Array::zeros(self.raw_dim());
137 let ncols = self.len_of(Axis(n - 1));
138 let nrows = self.len_of(Axis(n - 2));
139 let indices = Array::from_iter(0..nrows);
140 Zip::from(self.rows())
141 .and(res.rows_mut())
142 .and_broadcast(&indices)
143 .for_each(|src, mut dst, row_num| {
144 let mut upper = match k >= 0 {
146 true => row_num.saturating_add(k as usize).saturating_add(1), false => row_num.saturating_sub((k + 1).unsigned_abs()), };
149 upper = min(upper, ncols);
150 dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
151 });
152
153 res
154 }
155}
156
#[cfg(test)]
mod tests
{
    use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
    use alloc::vec;

    #[test]
    fn test_keep_order()
    {
        let x = Array2::<f64>::ones((3, 3).f());
        let res = x.triu(0);
        assert!(dimension::is_layout_f(&res.dim, &res.strides));

        let res = x.tril(0);
        assert!(dimension::is_layout_f(&res.dim, &res.strides));
    }

    #[test]
    fn test_0d()
    {
        let x = Array0::<f64>::ones(());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array0::<f64>::ones(().f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    #[test]
    fn test_1d()
    {
        let x = array![1, 2, 3];
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array1::<f64>::ones(3.f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    #[test]
    fn test_2d()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);

        let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();

        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    }

    #[test]
    fn test_2d_single()
    {
        let x = array![[1]];

        assert_eq!(x.triu(0), array![[1]]);
        assert_eq!(x.tril(0), array![[1]]);
        assert_eq!(x.triu(1), array![[0]]);
        assert_eq!(x.tril(1), array![[1]]);
        assert_eq!(x.triu(-1), array![[1]]);
        assert_eq!(x.tril(-1), array![[0]]);
    }

    #[test]
    fn test_3d()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );

        let x = Array3::from_shape_vec(
            (3, 3, 3).f(),
            vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
        )
        .unwrap();

        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );
    }

    #[test]
    fn test_f_order_unit_matrix_axis()
    {
        // Transposing the last two axes of these F-order arrays yields another
        // F-but-not-C array; the F-order fast path must not recurse on it.
        let x = Array3::<i32>::ones((2, 3, 1).f());
        assert_eq!(x.triu(0), array![[[1], [0], [0]], [[1], [0], [0]]]);
        assert_eq!(x.tril(0), Array3::<i32>::ones((2, 3, 1)));

        let x = Array3::<i32>::ones((2, 1, 3).f());
        assert_eq!(x.triu(0), Array3::<i32>::ones((2, 1, 3)));
        assert_eq!(x.tril(0), array![[[1, 0, 0]], [[1, 0, 0]]]);
    }

    #[test]
    fn test_off_axis()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        let res = x.triu(1);
        assert_eq!(
            res,
            array![
                [[0, 2, 3], [0, 0, 6], [0, 0, 0]],
                [[0, 11, 12], [0, 0, 15], [0, 0, 0]],
                [[0, 20, 21], [0, 0, 24], [0, 0, 0]]
            ]
        );

        let res = x.triu(-1);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [4, 5, 6], [0, 8, 9]],
                [[10, 11, 12], [13, 14, 15], [0, 17, 18]],
                [[19, 20, 21], [22, 23, 24], [0, 26, 27]]
            ]
        );
    }

    #[test]
    fn test_odd_shape()
    {
        let x = array![[1, 2, 3], [4, 5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);

        let x = array![[1, 2], [3, 4], [5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
    }

    #[test]
    fn test_odd_k()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let z = Array2::zeros([3, 3]);
        assert_eq!(x.triu(isize::MIN), x);
        assert_eq!(x.tril(isize::MIN), z);
        assert_eq!(x.triu(isize::MAX), z);
        assert_eq!(x.tril(isize::MAX), x);
    }

    #[test]
    fn test_odd_k_f_order()
    {
        let x = Array2::<i32>::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
        let z = Array2::<i32>::zeros([3, 3]);
        assert_eq!(x.triu(isize::MIN), x);
        assert_eq!(x.tril(isize::MIN), z);
        assert_eq!(x.triu(isize::MAX), z);
        assert_eq!(x.tril(isize::MAX), x);
    }
}