rayon/iter/
chain.rs

1use super::plumbing::*;
2use super::*;
3use rayon_core::join;
4use std::iter;
5
6/// `Chain` is an iterator that joins `b` after `a` in one continuous iterator.
7/// This struct is created by the [`chain()`] method on [`ParallelIterator`]
8///
9/// [`chain()`]: trait.ParallelIterator.html#method.chain
10/// [`ParallelIterator`]: trait.ParallelIterator.html
11#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
12#[derive(Debug, Clone)]
13pub struct Chain<A, B>
14where
15    A: ParallelIterator,
16    B: ParallelIterator<Item = A::Item>,
17{
18    a: A,
19    b: B,
20}
21
22impl<A, B> Chain<A, B>
23where
24    A: ParallelIterator,
25    B: ParallelIterator<Item = A::Item>,
26{
27    /// Creates a new `Chain` iterator.
28    pub(super) fn new(a: A, b: B) -> Self {
29        Chain { a, b }
30    }
31}
32
33impl<A, B> ParallelIterator for Chain<A, B>
34where
35    A: ParallelIterator,
36    B: ParallelIterator<Item = A::Item>,
37{
38    type Item = A::Item;
39
40    fn drive_unindexed<C>(self, consumer: C) -> C::Result
41    where
42        C: UnindexedConsumer<Self::Item>,
43    {
44        let Chain { a, b } = self;
45
46        // If we returned a value from our own `opt_len`, then the collect consumer in particular
47        // will balk at being treated like an actual `UnindexedConsumer`.  But when we do know the
48        // length, we can use `Consumer::split_at` instead, and this is still harmless for other
49        // truly-unindexed consumers too.
50        let (left, right, reducer) = if let Some(len) = a.opt_len() {
51            consumer.split_at(len)
52        } else {
53            let reducer = consumer.to_reducer();
54            (consumer.split_off_left(), consumer, reducer)
55        };
56
57        let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
58        reducer.reduce(a, b)
59    }
60
61    fn opt_len(&self) -> Option<usize> {
62        self.a.opt_len()?.checked_add(self.b.opt_len()?)
63    }
64}
65
66impl<A, B> IndexedParallelIterator for Chain<A, B>
67where
68    A: IndexedParallelIterator,
69    B: IndexedParallelIterator<Item = A::Item>,
70{
71    fn drive<C>(self, consumer: C) -> C::Result
72    where
73        C: Consumer<Self::Item>,
74    {
75        let Chain { a, b } = self;
76        let (left, right, reducer) = consumer.split_at(a.len());
77        let (a, b) = join(|| a.drive(left), || b.drive(right));
78        reducer.reduce(a, b)
79    }
80
81    fn len(&self) -> usize {
82        self.a.len().checked_add(self.b.len()).expect("overflow")
83    }
84
85    fn with_producer<CB>(self, callback: CB) -> CB::Output
86    where
87        CB: ProducerCallback<Self::Item>,
88    {
89        let a_len = self.a.len();
90        return self.a.with_producer(CallbackA {
91            callback,
92            a_len,
93            b: self.b,
94        });
95
96        struct CallbackA<CB, B> {
97            callback: CB,
98            a_len: usize,
99            b: B,
100        }
101
102        impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
103        where
104            B: IndexedParallelIterator,
105            CB: ProducerCallback<B::Item>,
106        {
107            type Output = CB::Output;
108
109            fn callback<A>(self, a_producer: A) -> Self::Output
110            where
111                A: Producer<Item = B::Item>,
112            {
113                self.b.with_producer(CallbackB {
114                    callback: self.callback,
115                    a_len: self.a_len,
116                    a_producer,
117                })
118            }
119        }
120
121        struct CallbackB<CB, A> {
122            callback: CB,
123            a_len: usize,
124            a_producer: A,
125        }
126
127        impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
128        where
129            A: Producer,
130            CB: ProducerCallback<A::Item>,
131        {
132            type Output = CB::Output;
133
134            fn callback<B>(self, b_producer: B) -> Self::Output
135            where
136                B: Producer<Item = A::Item>,
137            {
138                let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
139                self.callback.callback(producer)
140            }
141        }
142    }
143}
144
145/// ////////////////////////////////////////////////////////////////////////
146
147struct ChainProducer<A, B>
148where
149    A: Producer,
150    B: Producer<Item = A::Item>,
151{
152    a_len: usize,
153    a: A,
154    b: B,
155}
156
157impl<A, B> ChainProducer<A, B>
158where
159    A: Producer,
160    B: Producer<Item = A::Item>,
161{
162    fn new(a_len: usize, a: A, b: B) -> Self {
163        ChainProducer { a_len, a, b }
164    }
165}
166
167impl<A, B> Producer for ChainProducer<A, B>
168where
169    A: Producer,
170    B: Producer<Item = A::Item>,
171{
172    type Item = A::Item;
173    type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
174
175    fn into_iter(self) -> Self::IntoIter {
176        ChainSeq::new(self.a.into_iter(), self.b.into_iter())
177    }
178
179    fn min_len(&self) -> usize {
180        Ord::max(self.a.min_len(), self.b.min_len())
181    }
182
183    fn max_len(&self) -> usize {
184        Ord::min(self.a.max_len(), self.b.max_len())
185    }
186
187    fn split_at(self, index: usize) -> (Self, Self) {
188        if index <= self.a_len {
189            let a_rem = self.a_len - index;
190            let (a_left, a_right) = self.a.split_at(index);
191            let (b_left, b_right) = self.b.split_at(0);
192            (
193                ChainProducer::new(index, a_left, b_left),
194                ChainProducer::new(a_rem, a_right, b_right),
195            )
196        } else {
197            let (a_left, a_right) = self.a.split_at(self.a_len);
198            let (b_left, b_right) = self.b.split_at(index - self.a_len);
199            (
200                ChainProducer::new(self.a_len, a_left, b_left),
201                ChainProducer::new(0, a_right, b_right),
202            )
203        }
204    }
205
206    fn fold_with<F>(self, mut folder: F) -> F
207    where
208        F: Folder<A::Item>,
209    {
210        folder = self.a.fold_with(folder);
211        if folder.full() {
212            folder
213        } else {
214            self.b.fold_with(folder)
215        }
216    }
217}
218
// ////////////////////////////////////////////////////////////////////////

/// Sequential iterator yielding every item of `a` followed by every item of `b`.
///
/// This is a thin wrapper around [`std::iter::Chain`] whose purpose is to implement
/// [`ExactSizeIterator`] for the chained pair.
struct ChainSeq<A, B> {
    chain: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `b` after `a`.
    ///
    /// Both inputs must be exact-size iterators over the same item type.
    fn new(a: A, b: B) -> ChainSeq<A, B>
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        ChainSeq { chain: a.chain(b) }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.chain.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }

    // Forward internal iteration to the wrapped `Chain`, which folds each half in turn.
    // Without this override the default `fold` would drive the chain one `next()` call at
    // a time, and adaptors built on `fold` (`for_each`, `sum`, ...) would do the same.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.fold(init, f)
    }
}

// The default `len()` is derived from `size_hint()`, which is forwarded to the inner chain.
impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.chain.next_back()
    }

    // Reverse counterpart of `fold`: forwarded for the same reason.
    fn rfold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.rfold(init, f)
    }
}
267}