// ndarray/zip/mod.rs

1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::partial::Partial;
18use crate::AssignElem;
19use crate::IntoDimension;
20use crate::Layout;
21
22use crate::dimension;
23use crate::indexes::{indices, Indices};
24use crate::split_at::{SplitAt, SplitPreference};
25
26pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset};
27
/// Return if the expression is a break value.
///
/// Evaluates a `FoldWhile` expression: yields the inner value of
/// `Continue(x)`, but short-circuits the *enclosing function* by
/// `return`ing any `Done` value unchanged.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(x) => x,
            x => return x,
        }
    };
}
37
/// Broadcast an array so that it acts like a larger size and/or shape array.
///
/// See [broadcasting](ArrayBase#broadcasting) for more information.
trait Broadcast<E>
where E: IntoDimension
{
    /// The resulting producer, with dimension `E::Dim`.
    type Output: NdProducer<Dim = E::Dim>;
    /// Broadcast the array to the new dimensions `shape`.
    ///
    /// ***Panics*** if broadcasting isn’t possible.
    #[track_caller]
    fn broadcast_unwrap(self, shape: E) -> Self::Output;
    private_decl! {}
}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout
55{
56    let n = dim.ndim();
57    if dimension::is_layout_c(dim, strides) {
58        // effectively one-dimensional => C and F layout compatible
59        if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
60            Layout::one_dimensional()
61        } else {
62            Layout::c()
63        }
64    } else if n > 1 && dimension::is_layout_f(dim, strides) {
65        Layout::f()
66    } else if n > 1 {
67        if dim[0] > 1 && strides[0] == 1 {
68            Layout::fpref()
69        } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
70            Layout::cpref()
71        } else {
72            Layout::none()
73        }
74    } else {
75        Layout::none()
76    }
77}
78
impl<S, D> ArrayBase<S, D>
where
    S: RawData,
    D: Dimension,
{
    /// Compute the `Layout` hints for this array's shape and strides.
    pub(crate) fn layout_impl(&self) -> Layout
    {
        array_layout(&self.dim, &self.strides)
    }
}
89
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    // Delegates to the inherent `broadcast_unwrap` on a borrow of `self`,
    // then rebuilds the view to restore the original lifetime `'a`.
    fn broadcast_unwrap(self, shape: E) -> Self::Output
    {
        #[allow(clippy::needless_borrow)]
        let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
        // SAFETY: `res` is a valid view of the same underlying data that
        // `self` (valid for 'a) points to; only the lifetime parameter of
        // the returned view changes.
        unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
    }
    private_impl! {}
}
104
/// A tuple of producers that can be traversed in lock step.
///
/// Implemented (via macro) for tuples of `NdProducer`s that share a
/// dimension type; each method mirrors the corresponding per-producer
/// operation across all tuple elements at once.
trait ZippableTuple: Sized
{
    /// Tuple of the producers' item types.
    type Item;
    /// Tuple of the producers' raw pointer types.
    type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
    /// Common dimension type of all producers.
    type Dim: Dimension;
    /// Tuple of the producers' stride types.
    type Stride: Copy;
    /// Pointers to the first element of each producer.
    fn as_ptr(&self) -> Self::Ptr;
    /// Dereference the given pointers into items.
    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
    /// Pointers to the element at (multidimensional) index `i`.
    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
    /// Per-producer strides along the axis numbered `index`.
    fn stride_of(&self, index: usize) -> Self::Stride;
    /// Per-producer strides for flat, contiguous traversal.
    fn contiguous_stride(&self) -> Self::Stride;
    /// Split every producer at `index` along `axis`.
    fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
118
119/// Lock step function application across several arrays or other producers.
120///
121/// Zip allows matching several producers to each other elementwise and applying
122/// a function over all tuples of elements (one item from each input at
123/// a time).
124///
125/// In general, the zip uses a tuple of producers
126/// ([`NdProducer`] trait) that all have to be of the
127/// same shape. The NdProducer implementation defines what its item type is
128/// (for example if it's a shared reference, mutable reference or an array
129/// view etc).
130///
131/// If all the input arrays are of the same memory layout the zip performs much
132/// better and the compiler can usually vectorize the loop (if applicable).
133///
134/// The order elements are visited is not specified. The producers don’t have to
135/// have the same item type.
136///
137/// The `Zip` has two methods for function application: `for_each` and
138/// `fold_while`. The zip object can be split, which allows parallelization.
139/// A read-only zip object (no mutable producers) can be cloned.
140///
141/// See also the [`azip!()`] which offers a convenient shorthand
142/// to common ways to use `Zip`.
143///
144/// ```
145/// use ndarray::Zip;
146/// use ndarray::Array2;
147///
148/// type M = Array2<f64>;
149///
150/// // Create four 2d arrays of the same size
151/// let mut a = M::zeros((64, 32));
152/// let b = M::from_elem(a.dim(), 1.);
153/// let c = M::from_elem(a.dim(), 2.);
154/// let d = M::from_elem(a.dim(), 3.);
155///
156/// // Example 1: Perform an elementwise arithmetic operation across
157/// // the four arrays a, b, c, d.
158///
159/// Zip::from(&mut a)
160///     .and(&b)
161///     .and(&c)
162///     .and(&d)
163///     .for_each(|w, &x, &y, &z| {
164///         *w += x + y * z;
165///     });
166///
167/// // Example 2: Create a new array `totals` with one entry per row of `a`.
168/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
169/// //  entry in `totals` with the sum across each row.
170/// //  This is possible because the producer for `totals` and the row producer
171/// //  for `a` have the same shape and dimensionality.
172/// //  The rows producer yields one array view (`row`) per iteration.
173///
174/// use ndarray::{Array1, Axis};
175///
176/// let mut totals = Array1::zeros(a.nrows());
177///
178/// Zip::from(&mut totals)
179///     .and(a.rows())
180///     .for_each(|totals, row| *totals = row.sum());
181///
182/// // Check the result against the built in `.sum_axis()` along axis 1.
183/// assert_eq!(totals, a.sum_axis(Axis(1)));
184///
185///
186/// // Example 3: Recreate Example 2 using map_collect to make a new array
187///
188/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
189///
190/// // Check the result against the previous example.
191/// assert_eq!(totals, totals2);
192/// ```
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D>
{
    /// Tuple of producers, all with dimension type `D`.
    parts: Parts,
    /// The common shape of all the parts.
    dimension: D,
    /// Intersection of the layout hints of all the parts.
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
204
impl<P, D> Zip<(P,), D>
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    /// Create a new `Zip` from the input array or other producer `p`.
    ///
    /// The Zip will take the exact dimension of `p` and all inputs
    /// must have the same dimensions (or be broadcast to them).
    pub fn from<IP>(p: IP) -> Self
    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
    {
        let array = p.into_producer();
        let dim = array.raw_dim();
        let layout = array.layout();
        Zip {
            dimension: dim,
            layout,
            parts: (array,),
            // Seed the c/f preference sum with this producer's tendency;
            // later `.and(..)` calls add their own.
            layout_tendency: layout.tendency(),
        }
    }
}
impl<P, D> Zip<(Indices<D>, P), D>
where
    D: Dimension + Copy,
    P: NdProducer<Dim = D>,
{
    /// Create a new `Zip` with an index producer and the producer `p`.
    ///
    /// The Zip will take the exact dimension of `p` and all inputs
    /// must have the same dimensions (or be broadcast to them).
    ///
    /// *Note:* Indexed zip has overhead.
    pub fn indexed<IP>(p: IP) -> Self
    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
    {
        let array = p.into_producer();
        let dim = array.raw_dim();
        // The index producer comes first in the parts tuple, then `p`.
        Zip::from(indices(dim)).and(array)
    }
}
247
/// Assert that `part`'s shape matches the zip's `dimension` exactly,
/// panicking (via `ndassert!`) with a descriptive message on mismatch.
#[inline]
fn zip_dimension_check<D, P>(dimension: &D, part: &P)
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    ndassert!(
        part.equal_dim(dimension),
        "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
        dimension,
        part.raw_dim()
    );
}
261
262impl<Parts, D> Zip<Parts, D>
263where D: Dimension
264{
265    /// Return a the number of element tuples in the Zip
266    pub fn size(&self) -> usize
267    {
268        self.dimension.size()
269    }
270
271    /// Return the length of `axis`
272    ///
273    /// ***Panics*** if `axis` is out of bounds.
274    #[track_caller]
275    fn len_of(&self, axis: Axis) -> usize
276    {
277        self.dimension[axis.index()]
278    }
279
280    fn prefer_f(&self) -> bool
281    {
282        !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
283    }
284
285    /// Return an *approximation* to the max stride axis; if
286    /// component arrays disagree, there may be no choice better than the
287    /// others.
288    fn max_stride_axis(&self) -> Axis
289    {
290        let i = if self.prefer_f() {
291            self.dimension
292                .slice()
293                .iter()
294                .rposition(|&len| len > 1)
295                .unwrap_or(self.dimension.ndim() - 1)
296        } else {
297            /* corder or default */
298            self.dimension
299                .slice()
300                .iter()
301                .position(|&len| len > 1)
302                .unwrap_or(0)
303        };
304        Axis(i)
305    }
306}
307
impl<P, D> Zip<P, D>
where D: Dimension
{
    // Dispatch the traversal: 0-d is a single element; contiguous layouts
    // use one flat inner loop; everything else takes the strided path.
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    // All parts share a contiguous layout: traverse every element with a
    // single flat loop using the parts' contiguous strides.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    unsafe fn inner<F, Acc>(
        &self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F,
    ) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple,
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            // fold_while! returns early from this fn, propagating Done.
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous traversal: pick the unroll axis from layout tendency.
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unroll axis to length 1 so the outer index iteration
        // skips it; the inner loop covers it with `inner_len` steps.
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        // See for_each_core_strided_c: the unroll axis is handled by the
        // inner loop, so it is collapsed here.
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    // Allocate an uninitialized output array matching this zip's preferred
    // (c/f) memory order; used by the parallel collect path.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
435
436impl<D, P1, P2> Zip<(P1, P2), D>
437where
438    D: Dimension,
439    P1: NdProducer<Dim = D>,
440    P1: NdProducer<Dim = D>,
441{
442    /// Debug assert traversal order is like c (including 1D case)
443    // Method placement: only used for binary Zip at the moment.
444    #[inline]
445    pub(crate) fn debug_assert_c_order(self) -> Self
446    {
447        debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
448                      self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
449                      "Assertion failed: traversal is not c-order or 1D for \
450                      layout {:?}, tendency {}, dimension {:?}",
451                      self.layout, self.layout_tendency, self.dimension);
452        self
453    }
454}
455
456/*
457trait Offset : Copy {
458    unsafe fn offset(self, off: isize) -> Self;
459    unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
460        self.offset(index as isize * stride)
461    }
462}
463
464impl<T> Offset for *mut T {
465    unsafe fn offset(self, off: isize) -> Self {
466        self.offset(off)
467    }
468}
469*/
470
/// A tuple of raw pointers that can be offset in lock step by a matching
/// tuple of strides.
trait OffsetTuple
{
    /// Tuple of stride values, one per pointer.
    type Args;
    /// Offset every pointer by `index` elements along its own stride.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}
476
// Base case: a single raw pointer offset by `index` strides.
impl<T> OffsetTuple for *mut T
{
    type Args = isize;
    // Caller must ensure the resulting pointer stays within (or one past)
    // the same allocation, per `pointer::offset`'s contract.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self
    {
        self.offset(index as isize * stride)
    }
}
485
// Implement `OffsetTuple` for pointer tuples of each arity: every pointer in
// the tuple is offset by its corresponding stride. `$param` names the tuple
// element types and `$q` names the matching stride bindings.
macro_rules! offset_impl {
    ($([$($param:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
            type Args = ($($param::Stride,)*);
            unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
                let ($($param, )*) = self;
                let ($($q, )*) = stride;
                ($(Offset::stride_offset($param, $q, index),)*)
            }
        }
        )+
    };
}

// Tuple arities 1 through 6.
offset_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
510
// Implement `ZippableTuple` for producer tuples of each arity by forwarding
// every operation element-wise. `$p` names the producers; `$q` supplies a
// second set of binding names for when producers and pointers are
// destructured in the same scope.
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
            type Item = ($($p::Item, )*);
            type Ptr = ($($p::Ptr, )*);
            type Dim = Dim;
            type Stride = ($($p::Stride,)* );

            fn stride_of(&self, index: usize) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.stride_of(Axis(index)), )*)
            }

            fn contiguous_stride(&self) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.contiguous_stride(), )*)
            }

            fn as_ptr(&self) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.as_ptr(), )*)
            }
            unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                let ($(ref $q ,)*) = *self;
                let ($($p,)*) = ptr;
                ($($q.as_ref($p),)*)
            }

            unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.uget_ptr(i), )*)
            }

            fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                // Split each producer, then regroup the halves into two
                // tuples of producers.
                let ($($p,)*) = self;
                let ($($p,)*) = (
                    $($p.split_at(axis, index), )*
                );
                (
                    ($($p.0,)*),
                    ($($p.1,)*)
                )
            }
        }
        )+
    };
}

// Tuple arities 1 through 6.
zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
569
// Generate the main `Zip` API for each producer-tuple arity.
//
// `$notlast` is `true` while the arity can still grow (enabling `and`,
// `map_collect`, etc. via `expand_if!`) and `false` at the maximum arity;
// `$p` are the producer type parameters.
macro_rules! map_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<D, $($p),*> Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            pub fn for_each<F>(mut self, mut function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each_core((), move |(), args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function($($p),*))
                });
            }

            /// Apply a fold function to all elements of the input arrays,
            /// visiting elements in lock step.
            ///
            /// # Example
            ///
            /// The expression `tr(AᵀB)` can be more efficiently computed as
            /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
            /// elements of the entry-wise product). It would be possible to
            /// evaluate this expression by first computing the entry-wise
            /// product, `A∘B`, and then computing the elementwise sum of that
            /// product, but it's possible to do this in a single loop (and
            /// avoid an extra heap allocation if `A` and `B` can't be
            /// consumed) by using `Zip`:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            ///
            /// let a = array![[1, 5], [3, 7]];
            /// let b = array![[2, 4], [8, 6]];
            ///
            /// // Without using `Zip`. This involves two loops and an extra
            /// // heap allocation for the result of `&a * &b`.
            /// let sum_prod_nonzip = (&a * &b).sum();
            /// // Using `Zip`. This is a single loop without any heap allocations.
            /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
            ///
            /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
            /// ```
            pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
            where
                F: FnMut(Acc, $($p::Item),*) -> Acc,
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function(acc, $($p),*))
                }).into_inner()
            }

            /// Apply a fold function to the input arrays while the return
            /// value is `FoldWhile::Continue`, visiting elements in lock step.
            ///
            pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
                -> FoldWhile<Acc>
                where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    function(acc, $($p),*)
                })
            }

            /// Tests if every element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for all elements.
            /// Returns `true` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
            /// ```
            pub fn all<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits on the first failing element via Done.
                !self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Continue(())
                    } else {
                        FoldWhile::Done(())
                    }
                }).is_done()
            }

            /// Tests if at least one element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for at least one element.
            /// Returns `false` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).any(|&a, &b| a == b));
            /// assert!(!Zip::from(&a).and(&b).any(|&a, &b| a - 1 == b));
            /// ```
            pub fn any<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits on the first matching element via Done.
                self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Done(())
                    } else {
                        FoldWhile::Continue(())
                    }
                }).is_done()
            }

            // The following methods exist only while the tuple can still
            // grow, i.e. $notlast == true.
            expand_if!(@bool [$notlast]

            /// Include the producer `p` in the Zip.
            ///
            /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
            #[track_caller]
            pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                let part = p.into_producer();
                zip_dimension_check(&self.dimension, &part);
                self.build_and(part)
            }

            /// Include the producer `p` in the Zip.
            ///
            /// ## Safety
            ///
            /// The caller must ensure that the producer's shape is equal to the Zip's shape.
            /// Uses assertions when debug assertions are enabled.
            #[allow(unused)]
            pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                #[cfg(debug_assertions)]
                {
                    self.and(p)
                }
                #[cfg(not(debug_assertions))]
                {
                    self.build_and(p.into_producer())
                }
            }

            /// Include the producer `p` in the Zip, broadcasting if needed.
            ///
            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
            ///
            /// ***Panics*** if broadcasting isn’t possible.
            #[track_caller]
            pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
                -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
                where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
                      D2: Dimension,
            {
                let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
                self.build_and(part)
            }

            // Append `part` to the parts tuple, merging its layout hints.
            fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
                where P: NdProducer<Dim=D>,
            {
                let part_layout = part.layout();
                let ($($p,)*) = self.parts;
                Zip {
                    parts: ($($p,)* part, ),
                    layout: self.layout.intersect(part_layout),
                    dimension: self.dimension,
                    layout_tendency: self.layout_tendency + part_layout.tendency(),
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect_owned(f)
            }

            pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
                -> ArrayBase<S, D>
                where S: DataOwned<Elem = R>
            {
                // safe because: all elements are written before the array is completed

                let shape = self.dimension.clone().set_f(self.prefer_f());
                let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
                    // Use partial to count the number of filled elements, and can drop the right
                    // number of elements on unwinding (if it happens during apply/collect).
                    unsafe {
                        let output_view = output.into_raw_view_mut().cast::<R>();
                        self.and(output_view)
                            .collect_with_partial(f)
                            .release_ownership();
                    }
                });
                unsafe {
                    output.assume_init()
                }
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.and(into)
                    .for_each(move |$($p, )* output_| {
                        output_.assign_elem(f($($p ),*));
                    });
            }

            );

            /// Split the `Zip` evenly in two.
            ///
            /// It will be split in the way that best preserves element locality.
            pub fn split(self) -> (Self, Self) {
                debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
                debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
                SplitPreference::split(self)
            }
        }

        expand_if!(@bool [$notlast]
            // For collect; Last producer is a RawViewMut
            #[allow(non_snake_case)]
            impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
                      PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
            {
                /// The inner workings of map_collect and par_map_collect
                ///
                /// Apply the function and collect the results into the output (last producer)
                /// which should be a raw array view; a Partial that owns the written
                /// elements is returned.
                ///
                /// Elements will be overwritten in place (in the sense of std::ptr::write).
                ///
                /// ## Safety
                ///
                /// The last producer is a RawArrayViewMut and must be safe to write into.
                /// The producer must be c- or f-contig and have the same layout tendency
                /// as the whole Zip.
                ///
                /// The returned Partial's proxy ownership of the elements must be handled,
                /// before the array the raw view points to realizes its ownership.
                pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
                    where F: FnMut($($p::Item,)* ) -> R
                {
                    // Get the last producer; and make a Partial that aliases its data pointer
                    let (.., ref output) = &self.parts;

                    // debug assert that the output is contiguous in the memory layout we need
                    if cfg!(debug_assertions) {
                        let out_layout = output.layout();
                        assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
                        assert!(
                            (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
                            (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
                            "layout tendency violation for self layout {:?}, output layout {:?},\
                            output shape {:?}",
                            self.layout, out_layout, output.raw_dim());
                    }

                    let mut partial = Partial::new(output.as_ptr());

                    // Apply the mapping function on this zip
                    // if we panic with unwinding; Partial will drop the written elements.
                    let partial_len = &mut partial.len;
                    self.for_each(move |$($p,)* output_elem: *mut R| {
                        output_elem.write(f($($p),*));
                        if std::mem::needs_drop::<R>() {
                            *partial_len += 1;
                        }
                    });

                    partial
                }
            }
        );

        // Splitting support (used e.g. for parallelization).
        impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn can_split(&self) -> bool { self.size() > 1 }

            fn split_preference(&self) -> (Axis, usize) {
                // Always split in a way that preserves layout (if any)
                let axis = self.max_stride_axis();
                let index = self.len_of(axis) / 2;
                (axis, index)
            }
        }

        impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
                let (p1, p2) = self.parts.split_at(axis, index);
                let (d1, d2) = self.dimension.split_at(axis, index);
                (Zip {
                    dimension: d1,
                    layout: self.layout,
                    parts: p1,
                    layout_tendency: self.layout_tendency,
                },
                Zip {
                    dimension: d2,
                    layout: self.layout,
                    parts: p2,
                    layout_tendency: self.layout_tendency,
                })
            }

        }

        )+
    };
}

// Arities 1 through 6; only the 6-ary zip (`false`) cannot be extended.
map_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}
920
/// Value controlling the execution of `.fold_while` on `Zip`.
///
/// Returned by the fold closure to either keep iterating (`Continue`)
/// or stop early (`Done`).
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T>
{
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}
930
931impl<T> FoldWhile<T>
932{
933    /// Return the inner value
934    pub fn into_inner(self) -> T
935    {
936        match self {
937            FoldWhile::Continue(x) | FoldWhile::Done(x) => x,
938        }
939    }
940
941    /// Return true if it is `Done`, false if `Continue`
942    pub fn is_done(&self) -> bool
943    {
944        match *self {
945            FoldWhile::Continue(_) => false,
946            FoldWhile::Done(_) => true,
947        }
948    }
949}