// rayon/iter/fold_chunks.rs

1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6use crate::math::div_round_up;
7
8/// `FoldChunks` is an iterator that groups elements of an underlying iterator and applies a
9/// function over them, producing a single value for each group.
10///
11/// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`]
12///
13/// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks
14/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
15#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16#[derive(Clone)]
17pub struct FoldChunks<I, ID, F>
18where
19    I: IndexedParallelIterator,
20{
21    base: I,
22    chunk_size: usize,
23    fold_op: F,
24    identity: ID,
25}
26
27impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("Fold")
30            .field("base", &self.base)
31            .field("chunk_size", &self.chunk_size)
32            .finish()
33    }
34}
35
36impl<I, ID, U, F> FoldChunks<I, ID, F>
37where
38    I: IndexedParallelIterator,
39    ID: Fn() -> U + Send + Sync,
40    F: Fn(U, I::Item) -> U + Send + Sync,
41    U: Send,
42{
43    /// Creates a new `FoldChunks` iterator
44    pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
45        FoldChunks {
46            base,
47            chunk_size,
48            identity,
49            fold_op,
50        }
51    }
52}
53
54impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
55where
56    I: IndexedParallelIterator,
57    ID: Fn() -> U + Send + Sync,
58    F: Fn(U, I::Item) -> U + Send + Sync,
59    U: Send,
60{
61    type Item = U;
62
63    fn drive_unindexed<C>(self, consumer: C) -> C::Result
64    where
65        C: Consumer<U>,
66    {
67        bridge(self, consumer)
68    }
69
70    fn opt_len(&self) -> Option<usize> {
71        Some(self.len())
72    }
73}
74
75impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
76where
77    I: IndexedParallelIterator,
78    ID: Fn() -> U + Send + Sync,
79    F: Fn(U, I::Item) -> U + Send + Sync,
80    U: Send,
81{
82    fn len(&self) -> usize {
83        div_round_up(self.base.len(), self.chunk_size)
84    }
85
86    fn drive<C>(self, consumer: C) -> C::Result
87    where
88        C: Consumer<Self::Item>,
89    {
90        bridge(self, consumer)
91    }
92
93    fn with_producer<CB>(self, callback: CB) -> CB::Output
94    where
95        CB: ProducerCallback<Self::Item>,
96    {
97        let len = self.base.len();
98        return self.base.with_producer(Callback {
99            chunk_size: self.chunk_size,
100            len,
101            identity: self.identity,
102            fold_op: self.fold_op,
103            callback,
104        });
105
106        struct Callback<CB, ID, F> {
107            chunk_size: usize,
108            len: usize,
109            identity: ID,
110            fold_op: F,
111            callback: CB,
112        }
113
114        impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
115        where
116            CB: ProducerCallback<U>,
117            ID: Fn() -> U + Send + Sync,
118            F: Fn(U, T) -> U + Send + Sync,
119        {
120            type Output = CB::Output;
121
122            fn callback<P>(self, base: P) -> CB::Output
123            where
124                P: Producer<Item = T>,
125            {
126                let identity = &self.identity;
127                let fold_op = &self.fold_op;
128                let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
129                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
130                self.callback.callback(producer)
131            }
132        }
133    }
134}
135
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    #[test]
    fn check_fold_chunks() {
        let words = "bishbashbosh!"
            .chars()
            .collect::<Vec<_>>()
            .into_par_iter()
            .fold_chunks(4, String::new, |mut s, c| {
                s.push(c);
                s
            })
            .collect::<Vec<_>>();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // 'closure' values for tests below
    fn id() -> i32 {
        0
    }
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks(0, id, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        assert_eq!(
            vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9],
            (1..10)
                .into_par_iter()
                .fold_chunks(3, id, sum)
                .collect::<Vec<i32>>()
        );
    }

    #[test]
    fn check_fold_chunks_empty() {
        let v: Vec<i32> = vec![];
        let expected: Vec<i32> = vec![];
        assert_eq!(
            expected,
            v.into_par_iter()
                .fold_chunks(2, id, sum)
                .collect::<Vec<i32>>()
        );
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(1, [1].par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        // (input, chunk_size, expected per-chunk sums); the last chunk of each
        // case may be shorter than chunk_size.
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = vec![];
            v.par_iter()
                .fold_chunks(n, || 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            // Reuse the buffer for the reversed run.
            res.clear();
            v.into_par_iter()
                .fold_chunks(n, || 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            assert_eq!(
                expected.into_iter().rev().collect::<Vec<u32>>(),
                res,
                "Case {} reversed failed",
                i
            );
        }
    }
}