// rayon/iter/fold_chunks_with.rs

1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6use crate::math::div_round_up;
7
8/// `FoldChunksWith` is an iterator that groups elements of an underlying iterator and applies a
9/// function over them, producing a single value for each group.
10///
11/// This struct is created by the [`fold_chunks_with()`] method on [`IndexedParallelIterator`]
12///
13/// [`fold_chunks_with()`]: trait.IndexedParallelIterator.html#method.fold_chunks
14/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
15#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16#[derive(Clone)]
17pub struct FoldChunksWith<I, U, F>
18where
19    I: IndexedParallelIterator,
20{
21    base: I,
22    chunk_size: usize,
23    item: U,
24    fold_op: F,
25}
26
27impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("Fold")
30            .field("base", &self.base)
31            .field("chunk_size", &self.chunk_size)
32            .field("item", &self.item)
33            .finish()
34    }
35}
36
37impl<I, U, F> FoldChunksWith<I, U, F>
38where
39    I: IndexedParallelIterator,
40    U: Send + Clone,
41    F: Fn(U, I::Item) -> U + Send + Sync,
42{
43    /// Creates a new `FoldChunksWith` iterator
44    pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
45        FoldChunksWith {
46            base,
47            chunk_size,
48            item,
49            fold_op,
50        }
51    }
52}
53
54impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
55where
56    I: IndexedParallelIterator,
57    U: Send + Clone,
58    F: Fn(U, I::Item) -> U + Send + Sync,
59{
60    type Item = U;
61
62    fn drive_unindexed<C>(self, consumer: C) -> C::Result
63    where
64        C: Consumer<U>,
65    {
66        bridge(self, consumer)
67    }
68
69    fn opt_len(&self) -> Option<usize> {
70        Some(self.len())
71    }
72}
73
74impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
75where
76    I: IndexedParallelIterator,
77    U: Send + Clone,
78    F: Fn(U, I::Item) -> U + Send + Sync,
79{
80    fn len(&self) -> usize {
81        div_round_up(self.base.len(), self.chunk_size)
82    }
83
84    fn drive<C>(self, consumer: C) -> C::Result
85    where
86        C: Consumer<Self::Item>,
87    {
88        bridge(self, consumer)
89    }
90
91    fn with_producer<CB>(self, callback: CB) -> CB::Output
92    where
93        CB: ProducerCallback<Self::Item>,
94    {
95        let len = self.base.len();
96        return self.base.with_producer(Callback {
97            chunk_size: self.chunk_size,
98            len,
99            item: self.item,
100            fold_op: self.fold_op,
101            callback,
102        });
103
104        struct Callback<CB, T, F> {
105            chunk_size: usize,
106            len: usize,
107            item: T,
108            fold_op: F,
109            callback: CB,
110        }
111
112        impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
113        where
114            CB: ProducerCallback<U>,
115            U: Send + Clone,
116            F: Fn(U, T) -> U + Send + Sync,
117        {
118            type Output = CB::Output;
119
120            fn callback<P>(self, base: P) -> CB::Output
121            where
122                P: Producer<Item = T>,
123            {
124                let item = self.item;
125                let fold_op = &self.fold_op;
126                let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
127                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
128                self.callback.callback(producer)
129            }
130        }
131    }
132}
133
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    #[test]
    fn check_fold_chunks_with() {
        let words = "bishbashbosh!"
            .chars()
            .collect::<Vec<_>>()
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut s, c| {
                s.push(c);
                s
            })
            .collect::<Vec<_>>();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // 'closure' value for tests below
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks_with(0, 0, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        assert_eq!(
            vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9],
            (1..10)
                .into_par_iter()
                .fold_chunks_with(3, 0, sum)
                .collect::<Vec<i32>>()
        );
    }

    #[test]
    fn check_fold_chunks_with_restarts_from_initial_item() {
        // Every chunk starts from its own copy of the initial value, never from the previous
        // chunk's result.
        assert_eq!(
            vec![10 + 1 + 2, 10 + 3 + 4, 10 + 5],
            (1..6)
                .into_par_iter()
                .fold_chunks_with(2, 10, sum)
                .collect::<Vec<i32>>()
        );
    }

    #[test]
    fn check_fold_chunks_with_empty() {
        let v: Vec<i32> = vec![];
        let expected: Vec<i32> = vec![];
        assert_eq!(
            expected,
            v.into_par_iter()
                .fold_chunks_with(2, 0, sum)
                .collect::<Vec<i32>>()
        );
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!(4, (0..8).into_par_iter().fold_chunks_with(2, 0, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(1, [1].par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks_with(3, 0, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            // chunks [0, 1, 2] and [3, 4]
            ((0..5).collect(), 3, vec![3, 7]),
            // a single chunk shorter than chunk_size
            (vec![1], 5, vec![1]),
            // chunks [0, 1, 2] and [3]
            ((0..4).collect(), 3, vec![3, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = vec![];
            v.par_iter()
                .fold_chunks_with(n, 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            res.clear();
            v.into_par_iter()
                .fold_chunks_with(n, 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            assert_eq!(
                expected.into_iter().rev().collect::<Vec<u32>>(),
                res,
                "Case {} reversed failed",
                i
            );
        }
    }
}