ndarray/zip/mod.rs
1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::partial::Partial;
18use crate::AssignElem;
19use crate::IntoDimension;
20use crate::Layout;
21
22use crate::dimension;
23use crate::indexes::{indices, Indices};
24use crate::split_at::{SplitAt, SplitPreference};
25
26pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset};
27
/// Unwrap the value carried by a `FoldWhile::Continue`, or return early from
/// the enclosing function with the `FoldWhile` unchanged (i.e. a `Done`).
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(value) => value,
            finished => return finished,
        }
    };
}
37
38/// Broadcast an array so that it acts like a larger size and/or shape array.
39///
40/// See [broadcasting](ArrayBase#broadcasting) for more information.
41trait Broadcast<E>
42where E: IntoDimension
43{
44 type Output: NdProducer<Dim = E::Dim>;
45 /// Broadcast the array to the new dimensions `shape`.
46 ///
47 /// ***Panics*** if broadcasting isn’t possible.
48 #[track_caller]
49 fn broadcast_unwrap(self, shape: E) -> Self::Output;
50 private_decl! {}
51}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout
55{
56 let n = dim.ndim();
57 if dimension::is_layout_c(dim, strides) {
58 // effectively one-dimensional => C and F layout compatible
59 if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
60 Layout::one_dimensional()
61 } else {
62 Layout::c()
63 }
64 } else if n > 1 && dimension::is_layout_f(dim, strides) {
65 Layout::f()
66 } else if n > 1 {
67 if dim[0] > 1 && strides[0] == 1 {
68 Layout::fpref()
69 } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
70 Layout::cpref()
71 } else {
72 Layout::none()
73 }
74 } else {
75 Layout::none()
76 }
77}
78
79impl<S, D> ArrayBase<S, D>
80where
81 S: RawData,
82 D: Dimension,
83{
84 pub(crate) fn layout_impl(&self) -> Layout
85 {
86 array_layout(&self.dim, &self.strides)
87 }
88}
89
/// Broadcasting an owned-by-value `ArrayView<'a, ..>` yields a view with the
/// same lifetime `'a` and the target dimension type `E::Dim`.
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    fn broadcast_unwrap(self, shape: E) -> Self::Output
    {
        // Broadcast through a borrow of `self`. The result's lifetime is tied
        // to that temporary borrow (`'_`), not to `'a`.
        // NOTE(review): presumably this resolves to the inherent
        // `ArrayBase::broadcast_unwrap(&self, ..)`; confirm against its definition.
        #[allow(clippy::needless_borrow)]
        let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
        // Re-wrap the broadcast view's pointer/dim/strides so that the returned
        // view carries the original lifetime `'a`.
        // SAFETY (assumed): the broadcast view aliases the same data as
        // `self`, which is valid for `'a`. TODO confirm that broadcasting never
        // points outside the original view's data.
        unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
    }
    private_impl! {}
}
104
105trait ZippableTuple: Sized
106{
107 type Item;
108 type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
109 type Dim: Dimension;
110 type Stride: Copy;
111 fn as_ptr(&self) -> Self::Ptr;
112 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
113 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
114 fn stride_of(&self, index: usize) -> Self::Stride;
115 fn contiguous_stride(&self) -> Self::Stride;
116 fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
117}
118
/// Lock step function application across several arrays or other producers.
///
/// Zip allows matching several producers to each other elementwise and applying
/// a function over all tuples of elements (one item from each input at
/// a time).
///
/// In general, the zip uses a tuple of producers
/// ([`NdProducer`] trait) that all have to be of the
/// same shape. The NdProducer implementation defines what its item type is
/// (for example if it's a shared reference, mutable reference or an array
/// view etc).
///
/// If all the input arrays are of the same memory layout the zip performs much
/// better and the compiler can usually vectorize the loop (if applicable).
///
/// The order in which elements are visited is not specified. The producers
/// don’t have to have the same item type.
///
/// The main methods for function application are `for_each` and
/// `fold_while`. The zip object can be split, which allows parallelization.
/// A read-only zip object (no mutable producers) can be cloned.
///
/// See also the [`azip!()`] macro, which offers a convenient shorthand
/// for common ways to use `Zip`.
///
/// ```
/// use ndarray::Zip;
/// use ndarray::Array2;
///
/// type M = Array2<f64>;
///
/// // Create four 2d arrays of the same size
/// let mut a = M::zeros((64, 32));
/// let b = M::from_elem(a.dim(), 1.);
/// let c = M::from_elem(a.dim(), 2.);
/// let d = M::from_elem(a.dim(), 3.);
///
/// // Example 1: Perform an elementwise arithmetic operation across
/// // the four arrays a, b, c, d.
///
/// Zip::from(&mut a)
///     .and(&b)
///     .and(&c)
///     .and(&d)
///     .for_each(|w, &x, &y, &z| {
///         *w += x + y * z;
///     });
///
/// // Example 2: Create a new array `totals` with one entry per row of `a`.
/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
/// //  entry in `totals` with the sum across each row.
/// //  This is possible because the producer for `totals` and the row producer
/// //  for `a` have the same shape and dimensionality.
/// //  The rows producer yields one array view (`row`) per iteration.
///
/// use ndarray::{Array1, Axis};
///
/// let mut totals = Array1::zeros(a.nrows());
///
/// Zip::from(&mut totals)
///     .and(a.rows())
///     .for_each(|totals, row| *totals = row.sum());
///
/// // Check the result against the built in `.sum_axis()` along axis 1.
/// assert_eq!(totals, a.sum_axis(Axis(1)));
///
///
/// // Example 3: Recreate Example 2 using map_collect to make a new array
///
/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
///
/// // Check the result against the previous example.
/// assert_eq!(totals, totals2);
/// ```
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D>
{
    /// The tuple of producers being zipped together.
    parts: Parts,
    /// The common shape of the zip. Producers added later are checked
    /// against, or broadcast to, this shape.
    dimension: D,
    /// The intersection of the layouts of all parts.
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
204
205impl<P, D> Zip<(P,), D>
206where
207 D: Dimension,
208 P: NdProducer<Dim = D>,
209{
210 /// Create a new `Zip` from the input array or other producer `p`.
211 ///
212 /// The Zip will take the exact dimension of `p` and all inputs
213 /// must have the same dimensions (or be broadcast to them).
214 pub fn from<IP>(p: IP) -> Self
215 where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
216 {
217 let array = p.into_producer();
218 let dim = array.raw_dim();
219 let layout = array.layout();
220 Zip {
221 dimension: dim,
222 layout,
223 parts: (array,),
224 layout_tendency: layout.tendency(),
225 }
226 }
227}
228impl<P, D> Zip<(Indices<D>, P), D>
229where
230 D: Dimension + Copy,
231 P: NdProducer<Dim = D>,
232{
233 /// Create a new `Zip` with an index producer and the producer `p`.
234 ///
235 /// The Zip will take the exact dimension of `p` and all inputs
236 /// must have the same dimensions (or be broadcast to them).
237 ///
238 /// *Note:* Indexed zip has overhead.
239 pub fn indexed<IP>(p: IP) -> Self
240 where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
241 {
242 let array = p.into_producer();
243 let dim = array.raw_dim();
244 Zip::from(indices(dim)).and(array)
245 }
246}
247
248#[inline]
249fn zip_dimension_check<D, P>(dimension: &D, part: &P)
250where
251 D: Dimension,
252 P: NdProducer<Dim = D>,
253{
254 ndassert!(
255 part.equal_dim(dimension),
256 "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
257 dimension,
258 part.raw_dim()
259 );
260}
261
262impl<Parts, D> Zip<Parts, D>
263where D: Dimension
264{
265 /// Return a the number of element tuples in the Zip
266 pub fn size(&self) -> usize
267 {
268 self.dimension.size()
269 }
270
271 /// Return the length of `axis`
272 ///
273 /// ***Panics*** if `axis` is out of bounds.
274 #[track_caller]
275 fn len_of(&self, axis: Axis) -> usize
276 {
277 self.dimension[axis.index()]
278 }
279
280 fn prefer_f(&self) -> bool
281 {
282 !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
283 }
284
285 /// Return an *approximation* to the max stride axis; if
286 /// component arrays disagree, there may be no choice better than the
287 /// others.
288 fn max_stride_axis(&self) -> Axis
289 {
290 let i = if self.prefer_f() {
291 self.dimension
292 .slice()
293 .iter()
294 .rposition(|&len| len > 1)
295 .unwrap_or(self.dimension.ndim() - 1)
296 } else {
297 /* corder or default */
298 self.dimension
299 .slice()
300 .iter()
301 .position(|&len| len > 1)
302 .unwrap_or(0)
303 };
304 Axis(i)
305 }
306}
307
impl<P, D> Zip<P, D>
where D: Dimension
{
    /// Drive the zip. Folds `function` over every element tuple, starting
    /// from `acc`, and stops early as soon as `function` returns
    /// `FoldWhile::Done`.
    ///
    /// Dispatch:
    /// - 0-dimensional zip: a single call on the element at the parts' base
    ///   pointers.
    /// - Layout `is` C or F order: one flat run (`for_each_core_contiguous`).
    /// - Otherwise: a strided traversal (`for_each_core_strided`).
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            // A 0-d zip has exactly one element, located at the base pointers.
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Traverse a zip whose layout is C or F order as one run of `size`
    /// elements. Every part starts at its base pointer and advances by its
    /// `contiguous_stride()`.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        // SAFETY (assumed): the layout check above means all parts share a
        // contiguous order, so `size` steps of `contiguous_stride()` from the
        // base pointers stay inside every part.
        // NOTE(review): confirm `NdProducer::contiguous_stride` upholds this.
        unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    ///
    /// Element `i` (for `i` in `0..len`) is read at `ptr.stride_offset(strides, i)`.
    /// Returns `Done` immediately if `function` does; otherwise `Continue` with
    /// the final accumulator.
    ///
    /// # Safety
    ///
    /// Every offset pointer `ptr + i * strides` for `i < len` must point to a
    /// valid element of each part.
    unsafe fn inner<F, Acc>(
        &self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F,
    ) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple,
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            // `fold_while!` returns early from `inner` on `FoldWhile::Done`.
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }

    /// Non-contiguous traversal. Unrolls over the last axis when ndim is 1 or
    /// the parts lean towards C order (tendency >= 0); otherwise over axis 0.
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            // `for_each_core` handles ndim == 0 before reaching this point.
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    //
    // The last axis's length becomes the inner-loop length. That axis is then
    // set to 1 in `self.dimension`, so `first_index`/`next_for` enumerate only
    // the outer indices; each of those starts one `inner` run.
    // Note: this leaves `self.dimension` modified. The public entry points
    // take the Zip by value, so it is not reused afterwards.
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                // SAFETY (assumed): `index` comes from iterating the reduced
                // dimension, so it is in bounds for every part, and the inner
                // run stays within the unrolled axis of length `inner_len`.
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    //
    // Mirrors `for_each_core_strided_c`, except that axis 0 is the inner loop
    // and the outer indices advance with `next_for_f`, which mutates the index
    // in place and returns `false` when exhausted. Presumably this is
    // column-major stepping; confirm against `Dimension::next_for_f`.
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    // SAFETY (assumed): as in the C variant, `index` is an
                    // in-bounds index of the reduced dimension.
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Allocate an uninitialized array with this zip's shape. `set_f` is given
    /// `prefer_f()`, so the output follows the zip's preferred memory order.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
435
436impl<D, P1, P2> Zip<(P1, P2), D>
437where
438 D: Dimension,
439 P1: NdProducer<Dim = D>,
440 P1: NdProducer<Dim = D>,
441{
442 /// Debug assert traversal order is like c (including 1D case)
443 // Method placement: only used for binary Zip at the moment.
444 #[inline]
445 pub(crate) fn debug_assert_c_order(self) -> Self
446 {
447 debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
448 self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
449 "Assertion failed: traversal is not c-order or 1D for \
450 layout {:?}, tendency {}, dimension {:?}",
451 self.layout, self.layout_tendency, self.dimension);
452 self
453 }
454}
455
456/*
457trait Offset : Copy {
458 unsafe fn offset(self, off: isize) -> Self;
459 unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
460 self.offset(index as isize * stride)
461 }
462}
463
464impl<T> Offset for *mut T {
465 unsafe fn offset(self, off: isize) -> Self {
466 self.offset(off)
467 }
468}
469*/
470
/// A pointer, or tuple of pointers, that can be advanced by `index` steps of
/// a given stride.
trait OffsetTuple
{
    /// Stride type used to advance `Self`.
    type Args;

    /// Return `self` moved forward by `index * stride` elements.
    ///
    /// # Safety
    ///
    /// The resulting pointer(s) must satisfy the requirements of the
    /// underlying pointer offset (stay within the same allocation).
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T
{
    type Args = isize;

    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self
    {
        let element_offset = stride * index as isize;
        self.offset(element_offset)
    }
}
485
486macro_rules! offset_impl {
487 ($([$($param:ident)*][ $($q:ident)*],)+) => {
488 $(
489 #[allow(non_snake_case)]
490 impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
491 type Args = ($($param::Stride,)*);
492 unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
493 let ($($param, )*) = self;
494 let ($($q, )*) = stride;
495 ($(Offset::stride_offset($param, $q, index),)*)
496 }
497 }
498 )+
499 };
500}
501
502offset_impl! {
503 [A ][ a],
504 [A B][ a b],
505 [A B C][ a b c],
506 [A B C D][ a b c d],
507 [A B C D E][ a b c d e],
508 [A B C D E F][ a b c d e f],
509}
510
/// Implement `ZippableTuple` for tuples of `NdProducer`s that share one
/// dimension type. Each method applies the matching `NdProducer` method to
/// every component and collects the results into a tuple.
///
/// Each entry is `[TypeParams...][bindings...]`. The type-parameter names are
/// reused as local bindings (hence `non_snake_case`). The second list supplies
/// a distinct set of names where two sets are needed at once (`as_ref`).
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
            type Item = ($($p::Item, )*);
            type Ptr = ($($p::Ptr, )*);
            type Dim = Dim;
            type Stride = ($($p::Stride,)* );

            fn stride_of(&self, index: usize) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.stride_of(Axis(index)), )*)
            }

            fn contiguous_stride(&self) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.contiguous_stride(), )*)
            }

            fn as_ptr(&self) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.as_ptr(), )*)
            }
            unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                // `$q` names the producers and `$p` names their pointers.
                let ($(ref $q ,)*) = *self;
                let ($($p,)*) = ptr;
                ($($q.as_ref($p),)*)
            }

            unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.uget_ptr(i), )*)
            }

            fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                let ($($p,)*) = self;
                // Split every component; each `$p` is rebound (shadowed) to a
                // `(first_half, second_half)` pair.
                let ($($p,)*) = (
                    $($p.split_at(axis, index), )*
                );
                // Regroup: all first halves, then all second halves.
                (
                    ($($p.0,)*),
                    ($($p.1,)*)
                )
            }
        }
        )+
    };
}

zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
569
570macro_rules! map_impl {
571 ($([$notlast:ident $($p:ident)*],)+) => {
572 $(
573 #[allow(non_snake_case)]
574 impl<D, $($p),*> Zip<($($p,)*), D>
575 where D: Dimension,
576 $($p: NdProducer<Dim=D> ,)*
577 {
578 /// Apply a function to all elements of the input arrays,
579 /// visiting elements in lock step.
580 pub fn for_each<F>(mut self, mut function: F)
581 where F: FnMut($($p::Item),*)
582 {
583 self.for_each_core((), move |(), args| {
584 let ($($p,)*) = args;
585 FoldWhile::Continue(function($($p),*))
586 });
587 }
588
589 /// Apply a fold function to all elements of the input arrays,
590 /// visiting elements in lock step.
591 ///
592 /// # Example
593 ///
594 /// The expression `tr(AᵀB)` can be more efficiently computed as
595 /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
596 /// elements of the entry-wise product). It would be possible to
597 /// evaluate this expression by first computing the entry-wise
598 /// product, `A∘B`, and then computing the elementwise sum of that
599 /// product, but it's possible to do this in a single loop (and
600 /// avoid an extra heap allocation if `A` and `B` can't be
601 /// consumed) by using `Zip`:
602 ///
603 /// ```
604 /// use ndarray::{array, Zip};
605 ///
606 /// let a = array![[1, 5], [3, 7]];
607 /// let b = array![[2, 4], [8, 6]];
608 ///
609 /// // Without using `Zip`. This involves two loops and an extra
610 /// // heap allocation for the result of `&a * &b`.
611 /// let sum_prod_nonzip = (&a * &b).sum();
612 /// // Using `Zip`. This is a single loop without any heap allocations.
613 /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
614 ///
615 /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
616 /// ```
617 pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
618 where
619 F: FnMut(Acc, $($p::Item),*) -> Acc,
620 {
621 self.for_each_core(acc, move |acc, args| {
622 let ($($p,)*) = args;
623 FoldWhile::Continue(function(acc, $($p),*))
624 }).into_inner()
625 }
626
627 /// Apply a fold function to the input arrays while the return
628 /// value is `FoldWhile::Continue`, visiting elements in lock step.
629 ///
630 pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
631 -> FoldWhile<Acc>
632 where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
633 {
634 self.for_each_core(acc, move |acc, args| {
635 let ($($p,)*) = args;
636 function(acc, $($p),*)
637 })
638 }
639
640 /// Tests if every element of the iterator matches a predicate.
641 ///
642 /// Returns `true` if `predicate` evaluates to `true` for all elements.
643 /// Returns `true` if the input arrays are empty.
644 ///
645 /// Example:
646 ///
647 /// ```
648 /// use ndarray::{array, Zip};
649 /// let a = array![1, 2, 3];
650 /// let b = array![1, 4, 9];
651 /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
652 /// ```
653 pub fn all<F>(mut self, mut predicate: F) -> bool
654 where F: FnMut($($p::Item),*) -> bool
655 {
656 !self.for_each_core((), move |_, args| {
657 let ($($p,)*) = args;
658 if predicate($($p),*) {
659 FoldWhile::Continue(())
660 } else {
661 FoldWhile::Done(())
662 }
663 }).is_done()
664 }
665
666 /// Tests if at least one element of the iterator matches a predicate.
667 ///
668 /// Returns `true` if `predicate` evaluates to `true` for at least one element.
669 /// Returns `false` if the input arrays are empty.
670 ///
671 /// Example:
672 ///
673 /// ```
674 /// use ndarray::{array, Zip};
675 /// let a = array![1, 2, 3];
676 /// let b = array![1, 4, 9];
677 /// assert!(Zip::from(&a).and(&b).any(|&a, &b| a == b));
678 /// assert!(!Zip::from(&a).and(&b).any(|&a, &b| a - 1 == b));
679 /// ```
680 pub fn any<F>(mut self, mut predicate: F) -> bool
681 where F: FnMut($($p::Item),*) -> bool
682 {
683 self.for_each_core((), move |_, args| {
684 let ($($p,)*) = args;
685 if predicate($($p),*) {
686 FoldWhile::Done(())
687 } else {
688 FoldWhile::Continue(())
689 }
690 }).is_done()
691 }
692
693 expand_if!(@bool [$notlast]
694
695 /// Include the producer `p` in the Zip.
696 ///
697 /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
698 #[track_caller]
699 pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
700 where P: IntoNdProducer<Dim=D>,
701 {
702 let part = p.into_producer();
703 zip_dimension_check(&self.dimension, &part);
704 self.build_and(part)
705 }
706
707 /// Include the producer `p` in the Zip.
708 ///
709 /// ## Safety
710 ///
711 /// The caller must ensure that the producer's shape is equal to the Zip's shape.
712 /// Uses assertions when debug assertions are enabled.
713 #[allow(unused)]
714 pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
715 where P: IntoNdProducer<Dim=D>,
716 {
717 #[cfg(debug_assertions)]
718 {
719 self.and(p)
720 }
721 #[cfg(not(debug_assertions))]
722 {
723 self.build_and(p.into_producer())
724 }
725 }
726
727 /// Include the producer `p` in the Zip, broadcasting if needed.
728 ///
729 /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
730 ///
731 /// ***Panics*** if broadcasting isn’t possible.
732 #[track_caller]
733 pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
734 -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
735 where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
736 D2: Dimension,
737 {
738 let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
739 self.build_and(part)
740 }
741
742 fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
743 where P: NdProducer<Dim=D>,
744 {
745 let part_layout = part.layout();
746 let ($($p,)*) = self.parts;
747 Zip {
748 parts: ($($p,)* part, ),
749 layout: self.layout.intersect(part_layout),
750 dimension: self.dimension,
751 layout_tendency: self.layout_tendency + part_layout.tendency(),
752 }
753 }
754
755 /// Map and collect the results into a new array, which has the same size as the
756 /// inputs.
757 ///
758 /// If all inputs are c- or f-order respectively, that is preserved in the output.
759 pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
760 self.map_collect_owned(f)
761 }
762
763 pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
764 -> ArrayBase<S, D>
765 where S: DataOwned<Elem = R>
766 {
767 // safe because: all elements are written before the array is completed
768
769 let shape = self.dimension.clone().set_f(self.prefer_f());
770 let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
771 // Use partial to count the number of filled elements, and can drop the right
772 // number of elements on unwinding (if it happens during apply/collect).
773 unsafe {
774 let output_view = output.into_raw_view_mut().cast::<R>();
775 self.and(output_view)
776 .collect_with_partial(f)
777 .release_ownership();
778 }
779 });
780 unsafe {
781 output.assume_init()
782 }
783 }
784
785 /// Map and assign the results into the producer `into`, which should have the same
786 /// size as the other inputs.
787 ///
788 /// The producer should have assignable items as dictated by the `AssignElem` trait,
789 /// for example `&mut R`.
790 pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
791 where Q: IntoNdProducer<Dim=D>,
792 Q::Item: AssignElem<R>
793 {
794 self.and(into)
795 .for_each(move |$($p, )* output_| {
796 output_.assign_elem(f($($p ),*));
797 });
798 }
799
800 );
801
802 /// Split the `Zip` evenly in two.
803 ///
804 /// It will be split in the way that best preserves element locality.
805 pub fn split(self) -> (Self, Self) {
806 debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
807 debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
808 SplitPreference::split(self)
809 }
810 }
811
812 expand_if!(@bool [$notlast]
813 // For collect; Last producer is a RawViewMut
814 #[allow(non_snake_case)]
815 impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
816 where D: Dimension,
817 $($p: NdProducer<Dim=D> ,)*
818 PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
819 {
820 /// The inner workings of map_collect and par_map_collect
821 ///
822 /// Apply the function and collect the results into the output (last producer)
823 /// which should be a raw array view; a Partial that owns the written
824 /// elements is returned.
825 ///
826 /// Elements will be overwritten in place (in the sense of std::ptr::write).
827 ///
828 /// ## Safety
829 ///
830 /// The last producer is a RawArrayViewMut and must be safe to write into.
831 /// The producer must be c- or f-contig and have the same layout tendency
832 /// as the whole Zip.
833 ///
834 /// The returned Partial's proxy ownership of the elements must be handled,
835 /// before the array the raw view points to realizes its ownership.
836 pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
837 where F: FnMut($($p::Item,)* ) -> R
838 {
839 // Get the last producer; and make a Partial that aliases its data pointer
840 let (.., ref output) = &self.parts;
841
842 // debug assert that the output is contiguous in the memory layout we need
843 if cfg!(debug_assertions) {
844 let out_layout = output.layout();
845 assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
846 assert!(
847 (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
848 (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
849 "layout tendency violation for self layout {:?}, output layout {:?},\
850 output shape {:?}",
851 self.layout, out_layout, output.raw_dim());
852 }
853
854 let mut partial = Partial::new(output.as_ptr());
855
856 // Apply the mapping function on this zip
857 // if we panic with unwinding; Partial will drop the written elements.
858 let partial_len = &mut partial.len;
859 self.for_each(move |$($p,)* output_elem: *mut R| {
860 output_elem.write(f($($p),*));
861 if std::mem::needs_drop::<R>() {
862 *partial_len += 1;
863 }
864 });
865
866 partial
867 }
868 }
869 );
870
871 impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
872 where D: Dimension,
873 $($p: NdProducer<Dim=D> ,)*
874 {
875 fn can_split(&self) -> bool { self.size() > 1 }
876
877 fn split_preference(&self) -> (Axis, usize) {
878 // Always split in a way that preserves layout (if any)
879 let axis = self.max_stride_axis();
880 let index = self.len_of(axis) / 2;
881 (axis, index)
882 }
883 }
884
885 impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
886 where D: Dimension,
887 $($p: NdProducer<Dim=D> ,)*
888 {
889 fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
890 let (p1, p2) = self.parts.split_at(axis, index);
891 let (d1, d2) = self.dimension.split_at(axis, index);
892 (Zip {
893 dimension: d1,
894 layout: self.layout,
895 parts: p1,
896 layout_tendency: self.layout_tendency,
897 },
898 Zip {
899 dimension: d2,
900 layout: self.layout,
901 parts: p2,
902 layout_tendency: self.layout_tendency,
903 })
904 }
905
906 }
907
908 )+
909 };
910}
911
912map_impl! {
913 [true P1],
914 [true P1 P2],
915 [true P1 P2 P3],
916 [true P1 P2 P3 P4],
917 [true P1 P2 P3 P4 P5],
918 [false P1 P2 P3 P4 P5 P6],
919}
920
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T>
{
    /// Keep folding, carrying this accumulator value forward.
    Continue(T),
    /// Stop folding; this value is the final result.
    Done(T),
}

impl<T> FoldWhile<T>
{
    /// Extract the carried value, whichever variant holds it.
    pub fn into_inner(self) -> T
    {
        match self {
            FoldWhile::Continue(value) => value,
            FoldWhile::Done(value) => value,
        }
    }

    /// `true` for `Done`, `false` for `Continue`.
    pub fn is_done(&self) -> bool
    {
        matches!(*self, FoldWhile::Done(_))
    }
}