proptest/strategy/
unions.rs

1//-
2// Copyright 2017 Jason Lingle
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10use crate::std_facade::{fmt, Arc, Vec};
11use core::cmp::{max, min};
12use core::u32;
13
14#[cfg(not(feature = "std"))]
15use num_traits::float::FloatCore;
16
17use crate::num::sample_uniform;
18use crate::strategy::{lazy::LazyValueTree, traits::*};
19use crate::test_runner::*;
20
/// A **relative** `weight` of a particular `Strategy` corresponding to `T`
/// coupled with `T` itself. The weight is currently given in `u32`.
///
/// This is the element type of the `Vec` accepted by `Union::new_weighted`.
pub type W<T> = (u32, T);
24
/// A **relative** `weight` of a particular `Strategy` corresponding to `T`
/// coupled with `Arc<T>`. The weight is currently given in `u32`.
///
/// `Union` stores its options as a `Vec` of these pairs, and the `Strategy`
/// implementations for `TupleUnion` are written over tuples of them.
pub type WA<T> = (u32, Arc<T>);
28
/// A `Strategy` which picks from one of several delegate `Strategy`s.
///
/// Every delegate carries a relative `u32` weight: `Union::new` and
/// `Union::or` assign a weight of 1, while `Union::new_weighted` takes
/// explicit weights.
///
/// See `Strategy::prop_union()`.
#[derive(Clone, Debug)]
#[must_use = "strategies do nothing unless used"]
pub struct Union<T: Strategy> {
    // In principle T could be any `Strategy + Clone`, but that isn't possible
    // for BC reasons with the 0.9 series.
    //
    // Each entry is a `(weight, strategy)` pair; the constructors assert that
    // the list is non-empty.
    options: Vec<WA<T>>,
}
39
40impl<T: Strategy> Union<T> {
41    /// Create a strategy which selects uniformly from the given delegate
42    /// strategies.
43    ///
44    /// When shrinking, after maximal simplification of the chosen element, the
45    /// strategy will move to earlier options and continue simplification with
46    /// those.
47    ///
48    /// ## Panics
49    ///
50    /// Panics if `options` is empty.
51    pub fn new(options: impl IntoIterator<Item = T>) -> Self {
52        let options: Vec<WA<T>> =
53            options.into_iter().map(|v| (1, Arc::new(v))).collect();
54        assert!(!options.is_empty());
55        Self { options }
56    }
57
58    pub(crate) fn try_new<E>(
59        it: impl Iterator<Item = Result<T, E>>,
60    ) -> Result<Self, E> {
61        let options: Vec<WA<T>> = it
62            .map(|r| r.map(|v| (1, Arc::new(v))))
63            .collect::<Result<_, _>>()?;
64
65        assert!(!options.is_empty());
66        Ok(Self { options })
67    }
68
69    /// Create a strategy which selects from the given delegate strategies.
70    ///
71    /// Each strategy is assigned a non-zero weight which determines how
72    /// frequently that strategy is chosen. For example, a strategy with a
73    /// weight of 2 will be chosen twice as frequently as one with a weight of
74    /// 1\.
75    ///
76    /// ## Panics
77    ///
78    /// Panics if `options` is empty or any element has a weight of 0.
79    ///
80    /// Panics if the sum of the weights overflows a `u32`.
81    pub fn new_weighted(options: Vec<W<T>>) -> Self {
82        assert!(!options.is_empty());
83        assert!(
84            !options.iter().any(|&(w, _)| 0 == w),
85            "Union option has a weight of 0"
86        );
87        assert!(
88            options.iter().map(|&(w, _)| u64::from(w)).sum::<u64>()
89                <= u64::from(u32::MAX),
90            "Union weights overflow u32"
91        );
92        let options =
93            options.into_iter().map(|(w, v)| (w, Arc::new(v))).collect();
94        Self { options }
95    }
96
97    /// Add `other` as an additional alternate strategy with weight 1.
98    pub fn or(mut self, other: T) -> Self {
99        self.options.push((1, Arc::new(other)));
100        self
101    }
102}
103
104fn pick_weighted<I: Iterator<Item = u32>>(
105    runner: &mut TestRunner,
106    weights1: I,
107    weights2: I,
108) -> usize {
109    let sum = weights1.map(u64::from).sum();
110    let weighted_pick = sample_uniform(runner, 0, sum);
111    weights2
112        .scan(0u64, |state, w| {
113            *state += u64::from(w);
114            Some(*state)
115        })
116        .filter(|&v| v <= weighted_pick)
117        .count()
118}
119
120impl<T: Strategy> Strategy for Union<T> {
121    type Tree = UnionValueTree<T>;
122    type Value = T::Value;
123
124    fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
125        fn extract_weight<V>(&(w, _): &WA<V>) -> u32 {
126            w
127        }
128
129        let pick = pick_weighted(
130            runner,
131            self.options.iter().map(extract_weight::<T>),
132            self.options.iter().map(extract_weight::<T>),
133        );
134
135        let mut options = Vec::with_capacity(pick);
136
137        // Delay initialization for all options less than pick.
138        for option in &self.options[0..pick] {
139            options.push(LazyValueTree::new(Arc::clone(&option.1), runner));
140        }
141
142        // Initialize the tree at pick so at least one value is available. Note
143        // that if generation for the value at pick fails, the entire strategy
144        // will fail. This seems like the right call.
145        options.push(LazyValueTree::new_initialized(
146            self.options[pick].1.new_tree(runner)?,
147        ));
148
149        Ok(UnionValueTree {
150            options,
151            pick,
152            min_pick: 0,
153            prev_pick: None,
154        })
155    }
156}
157
// Accessor used by `lazy_union_value_tree_body!` for `Vec`-backed unions.
//
// Binds `$binding` to a reference into `$owner.options[$index]` and then
// evaluates `$action`. The bracketed prefix is either `[]` (shared borrow) or
// `[mut]` (mutable borrow).
macro_rules! access_vec {
    ([$($mutability:tt)*] $binding:ident = $owner:expr, $index:expr, $action:block) => {{
        let $binding = &$($mutability)* $owner.options[$index];
        $action
    }};
}
164
/// `ValueTree` corresponding to `Union`.
pub struct UnionValueTree<T: Strategy> {
    // One (possibly not yet initialised) value tree per option, indexed by
    // `pick`.
    options: Vec<LazyValueTree<T>>,
    // This struct maintains the invariant that between function calls,
    // `pick` and `prev_pick` (if Some) always point to initialized
    // trees.
    //
    // Index of the option whose value is currently reported by `current()`.
    pick: usize,
    // Lowest index `simplify()` is allowed to move `pick` to; raised by
    // `complicate()` when it restores a previous pick.
    min_pick: usize,
    // The pick that was active before `simplify()` last moved to another
    // option; `complicate()` restores it and then clears this field.
    prev_pick: Option<usize>,
}
175
// Expands to the items of a `ValueTree` impl shared by `UnionValueTree` and
// `TupleUnionValueTree`.
//
// `$typ` is the produced value type. `$access` names an accessor macro
// invoked as `$access!([] name = self, index, { body })` or with `[mut]`,
// which binds `name` to the lazy value tree stored at `index` and evaluates
// `body`.
//
// The implementing type must have `pick`, `min_pick` and `prev_pick` fields.
macro_rules! lazy_union_value_tree_body {
    ($typ:ty, $access:ident) => {
        type Value = $typ;

        // Current value of the tree at `self.pick`; panics if that tree has
        // not been initialised.
        fn current(&self) -> Self::Value {
            $access!([] opt = self, self.pick, {
                opt.as_inner().unwrap_or_else(||
                    panic!(
                        "value tree at self.pick = {} must be initialized",
                        self.pick,
                    )
                ).current()
            })
        }

        // First tries to simplify the currently picked tree. Once that tree
        // cannot simplify any further, moves `pick` down to the nearest
        // earlier option (not below `min_pick`) that can be initialised.
        fn simplify(&mut self) -> bool {
            let orig_pick = self.pick;
            if $access!([mut] opt = self, orig_pick, {
                opt.as_inner_mut().unwrap_or_else(||
                    panic!(
                        "value tree at self.pick = {} must be initialized",
                        orig_pick,
                    )
                ).simplify()
            }) {
                // The inner tree changed, so there is no option switch for
                // `complicate()` to undo.
                self.prev_pick = None;
                return true;
            }

            assert!(
                self.pick >= self.min_pick,
                "self.pick = {} should never go below self.min_pick = {}",
                self.pick,
                self.min_pick,
            );
            if self.pick == self.min_pick {
                // No more simplification to be done.
                return false;
            }

            // self.prev_pick is always a valid pick.
            self.prev_pick = Some(self.pick);

            // Walk downwards, trying to initialise each earlier option in
            // turn; options whose initialisation fails are skipped.
            let mut next_pick = self.pick;
            while next_pick > self.min_pick {
                next_pick -= 1;
                let initialized = $access!([mut] opt = self, next_pick, {
                    opt.maybe_init();
                    opt.is_initialized()
                });
                if initialized {
                    // next_pick was correctly initialized above.
                    self.pick = next_pick;
                    return true;
                }
            }

            // NOTE(review): `prev_pick` stays `Some(..)` on this path even
            // though `pick` did not move -- confirm this is intended.
            false
        }

        // Undoes an option switch made by `simplify()` if there is one to
        // undo; otherwise delegates to the currently picked tree.
        fn complicate(&mut self) -> bool {
            if let Some(pick) = self.prev_pick {
                // simplify() ensures that the previous pick was initialized.
                self.pick = pick;
                // Raising `min_pick` stops later `simplify()` calls from
                // moving below the restored option again.
                self.min_pick = pick;
                self.prev_pick = None;
                true
            } else {
                let pick = self.pick;
                $access!([mut] opt = self, pick, {
                    opt.as_inner_mut().unwrap_or_else(||
                        panic!(
                            "value tree at self.pick = {} must be initialized",
                            pick,
                        )
                    ).complicate()
                })
            }
        }
    }
}
257
// The `ValueTree` items come from the shared macro body; `access_vec` is the
// accessor that indexes into the `options` vector.
impl<T: Strategy> ValueTree for UnionValueTree<T> {
    lazy_union_value_tree_body!(T::Value, access_vec);
}
261
262impl<T: Strategy> Clone for UnionValueTree<T>
263where
264    T::Tree: Clone,
265{
266    fn clone(&self) -> Self {
267        Self {
268            options: self.options.clone(),
269            pick: self.pick,
270            min_pick: self.min_pick,
271            prev_pick: self.prev_pick,
272        }
273    }
274}
275
276impl<T: Strategy> fmt::Debug for UnionValueTree<T>
277where
278    T::Tree: fmt::Debug,
279{
280    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
281        f.debug_struct("UnionValueTree")
282            .field("options", &self.options)
283            .field("pick", &self.pick)
284            .field("min_pick", &self.min_pick)
285            .field("prev_pick", &self.prev_pick)
286            .finish()
287    }
288}
289
// Defines an accessor macro named `$name` for tuple-backed unions, with the
// same call shape as `access_vec!`:
// `$name!([] dst = this, ix, { body })` or with `[mut]`.
//
// `$b` must be passed a literal `$` token. A macro cannot write `$` directly
// when emitting another macro's metavariables, so every `$b foo` below
// becomes `$foo` in the generated macro.
//
// `$n` lists the tuple indices after slot 0. Slot 0 is stored directly, while
// each later slot is an `Option` that must be `Some` when accessed.
macro_rules! def_access_tuple {
    ($b:tt $name:ident, $($n:tt)*) => {
        macro_rules! $name {
            ([$b($b muta:tt)*] $b dst:ident = $b this:expr,
             $b ix:expr, $b body:block) => {
                match $b ix {
                    // Slot 0 is not wrapped in `Option`, so borrow it directly.
                    0 => {
                        let $b dst = &$b($b muta)* $b this.options.0;
                        $b body
                    },
                    // One arm per remaining tuple index.
                    $(
                        $n => {
                            if let Some(ref $b($b muta)* $b dst) =
                                $b this.options.$n
                            {
                                $b body
                            } else {
                                panic!("TupleUnion tried to access \
                                        uninitialised slot {}", $n)
                            }
                        },
                    )*
                    _ => panic!("TupleUnion tried to access out-of-range \
                                 slot {}", $b ix),
                }
            }
        }
    }
}

// One accessor per supported tuple arity (2 through 10); the suffix is the
// arity, with `A` standing for 10.
def_access_tuple!($ access_tuple2, 1);
def_access_tuple!($ access_tuple3, 1 2);
def_access_tuple!($ access_tuple4, 1 2 3);
def_access_tuple!($ access_tuple5, 1 2 3 4);
def_access_tuple!($ access_tuple6, 1 2 3 4 5);
def_access_tuple!($ access_tuple7, 1 2 3 4 5 6);
def_access_tuple!($ access_tuple8, 1 2 3 4 5 6 7);
def_access_tuple!($ access_tuple9, 1 2 3 4 5 6 7 8);
def_access_tuple!($ access_tupleA, 1 2 3 4 5 6 7 8 9);
329
/// Similar to `Union`, but internally uses a tuple to hold the strategies.
///
/// This allows better performance than vanilla `Union` since one does not need
/// to resort to boxing and dynamic dispatch to handle heterogeneous
/// strategies.
///
/// Value trees for variants that aren't picked at first are generated lazily.
#[must_use = "strategies do nothing unless used"]
#[derive(Clone, Copy, Debug)]
pub struct TupleUnion<T>(T);

impl<T> TupleUnion<T> {
    /// Wrap `tuple` in a `TupleUnion`.
    ///
    /// Any `T` is accepted by the struct definition, but the result is only
    /// useful when `tuple` is a 2- to 10-tuple of `(u32, Arc<impl Strategy>)`
    /// pairs whose strategies all ultimately produce the same value. Each
    /// `u32` is the relative weight of the strategy it is paired with.
    /// `WA<S>` is available as an alias for `(u32, Arc<S>)`.
    ///
    /// Using this constructor directly is discouraged; prefer to use
    /// `prop_oneof!` since it is generally clearer.
    pub fn new(tuple: T) -> Self {
        Self(tuple)
    }
}
357
// Implements `Strategy` for `TupleUnion` over a tuple of `WA<_>` pairs.
//
// The first strategy is always the generic parameter `A`; each `$gen $ix`
// pair names one further strategy type together with its tuple index. All
// strategies must produce `A::Value`.
macro_rules! tuple_union {
    ($($gen:ident $ix:tt)*) => {
        impl<A : Strategy, $($gen: Strategy<Value = A::Value>),*>
        Strategy for TupleUnion<(WA<A>, $(WA<$gen>),*)> {
            // Slot 0 is always present; later slots are `None` when they lie
            // after the initial pick.
            type Tree = TupleUnionValueTree<
                (LazyValueTree<A>, $(Option<LazyValueTree<$gen>>),*)>;
            type Value = A::Value;

            fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
                // Gather the weights of every slot, in tuple order.
                let weights = [((self.0).0).0, $(((self.0).$ix).0),*];
                let pick = pick_weighted(runner, weights.iter().cloned(),
                                         weights.iter().cloned());

                Ok(TupleUnionValueTree {
                    options: (
                        // Slot 0: generated now if it is the pick, otherwise
                        // deferred.
                        if 0 == pick {
                            LazyValueTree::new_initialized(
                                ((self.0).0).1.new_tree(runner)?)
                        } else {
                            LazyValueTree::new(
                                Arc::clone(&((self.0).0).1), runner)
                        },
                        // Remaining slots: the pick is generated now, slots
                        // before it are deferred, and slots after it are left
                        // as `None`.
                        $(
                        if $ix == pick {
                            Some(LazyValueTree::new_initialized(
                                 ((self.0).$ix).1.new_tree(runner)?))
                        } else if $ix < pick {
                            Some(LazyValueTree::new(
                                    Arc::clone(&((self.0).$ix).1), runner))
                        } else {
                            None
                        }),*),
                    pick: pick,
                    min_pick: 0,
                    prev_pick: None,
                })
            }
        }
    }
}

// One `Strategy` impl per supported tuple arity, 2 through 10.
tuple_union!(B 1);
tuple_union!(B 1 C 2);
tuple_union!(B 1 C 2 D 3);
tuple_union!(B 1 C 2 D 3 E 4);
tuple_union!(B 1 C 2 D 3 E 4 F 5);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7 I 8);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7 I 8 J 9);
408
/// `ValueTree` type produced by `TupleUnion`.
#[derive(Clone, Copy, Debug)]
pub struct TupleUnionValueTree<T> {
    // Tuple of per-option value trees: slot 0 is a `LazyValueTree`, every
    // later slot an `Option<LazyValueTree<_>>`.
    options: T,
    // Index of the option whose value is currently reported by `current()`.
    pick: usize,
    // Lowest index `simplify()` is allowed to move `pick` to.
    min_pick: usize,
    // The pick that was active before `simplify()` last moved to another
    // option; restored by `complicate()`.
    prev_pick: Option<usize>,
}
417
418macro_rules! value_tree_tuple {
419    ($access:ident, $($gen:ident)*) => {
420        impl<A : Strategy, $($gen: Strategy<Value = A::Value>),*> ValueTree
421        for TupleUnionValueTree<
422            (LazyValueTree<A>, $(Option<LazyValueTree<$gen>>),*)
423        > {
424            lazy_union_value_tree_body!(A::Value, $access);
425        }
426    }
427}
428
429value_tree_tuple!(access_tuple2, B);
430value_tree_tuple!(access_tuple3, B C);
431value_tree_tuple!(access_tuple4, B C D);
432value_tree_tuple!(access_tuple5, B C D E);
433value_tree_tuple!(access_tuple6, B C D E F);
434value_tree_tuple!(access_tuple7, B C D E F G);
435value_tree_tuple!(access_tuple8, B C D E F G H);
436value_tree_tuple!(access_tuple9, B C D E F G H I);
437value_tree_tuple!(access_tupleA, B C D E F G H I J);
438
// Total weight that `float_to_weight` splits between its two results.
const WEIGHT_BASE: u32 = 0x8000_0000;

/// Convert a floating-point probability in the range (0.0,1.0) into a pair
/// of integer weights usable with `Union` and similar.
///
/// The first element of the result is the weight for `f`; the second is the
/// weight for `1.0 - f`.
///
/// No guarantee is made about the range of weights produced, other than
/// that the sum of the two never overflows a `u32`. Combining the returned
/// weights with any other weights is therefore generally not meaningful.
///
/// ## Panics
///
/// Panics if `f` is not a real number between 0.0 and 1.0, both exclusive.
pub fn float_to_weight(f: f64) -> (u32, u32) {
    assert!(f > 0.0 && f < 1.0, "Invalid probability: {}", f);

    let scaled = (f * f64::from(WEIGHT_BASE)).round() as u32;
    // Keep the result within 1..=WEIGHT_BASE-1 so that neither weight is 0.
    let pos = min(max(scaled, 1), WEIGHT_BASE - 1);

    (pos, WEIGHT_BASE - pos)
}
467
468#[cfg(test)]
469mod test {
470    use super::*;
471    use crate::strategy::just::Just;
472
473    // FIXME(2018-06-01): figure out a way to run this test on no_std.
474    // The problem is that the default seed is fixed and does not produce
475    // enough passed tests. We need some universal source of non-determinism
476    // for the seed, which is unlikely.
477    #[cfg(feature = "std")]
478    #[test]
479    fn test_union() {
480        let input = (10u32..20u32).prop_union(30u32..40u32);
481        // Expect that 25% of cases pass (left input happens to be < 15, and
482        // left is chosen as initial value). Of the 75% that fail, 50% should
483        // converge to 15 and 50% to 30 (the latter because the left is beneath
484        // the passing threshold).
485        let mut passed = 0;
486        let mut converged_low = 0;
487        let mut converged_high = 0;
488        let mut runner = TestRunner::deterministic();
489        for _ in 0..256 {
490            let case = input.new_tree(&mut runner).unwrap();
491            let result = runner.run_one(case, |v| {
492                prop_assert!(v < 15);
493                Ok(())
494            });
495
496            match result {
497                Ok(true) => passed += 1,
498                Err(TestError::Fail(_, 15)) => converged_low += 1,
499                Err(TestError::Fail(_, 30)) => converged_high += 1,
500                e => panic!("Unexpected result: {:?}", e),
501            }
502        }
503
504        assert!(passed >= 32 && passed <= 96, "Bad passed count: {}", passed);
505        assert!(
506            converged_low >= 32 && converged_low <= 160,
507            "Bad converged_low count: {}",
508            converged_low
509        );
510        assert!(
511            converged_high >= 32 && converged_high <= 160,
512            "Bad converged_high count: {}",
513            converged_high
514        );
515    }
516
517    #[test]
518    fn test_union_weighted() {
519        let input = Union::new_weighted(vec![
520            (1, Just(0usize)),
521            (2, Just(1usize)),
522            (1, Just(2usize)),
523        ]);
524
525        let mut counts = [0, 0, 0];
526        let mut runner = TestRunner::deterministic();
527        for _ in 0..65536 {
528            counts[input.new_tree(&mut runner).unwrap().current()] += 1;
529        }
530
531        println!("{:?}", counts);
532        assert!(counts[0] > 0);
533        assert!(counts[2] > 0);
534        assert!(counts[1] > counts[0] * 3 / 2);
535        assert!(counts[1] > counts[2] * 3 / 2);
536    }
537
538    #[test]
539    fn test_union_sanity() {
540        check_strategy_sanity(
541            Union::new_weighted(vec![
542                (1, 0i32..100),
543                (2, 200i32..300),
544                (1, 400i32..500),
545            ]),
546            None,
547        );
548    }
549
550    // FIXME(2018-06-01): See note on `test_union`.
551    #[cfg(feature = "std")]
552    #[test]
553    fn test_tuple_union() {
554        let input = TupleUnion::new((
555            (1, Arc::new(10u32..20u32)),
556            (1, Arc::new(30u32..40u32)),
557        ));
558        // Expect that 25% of cases pass (left input happens to be < 15, and
559        // left is chosen as initial value). Of the 75% that fail, 50% should
560        // converge to 15 and 50% to 30 (the latter because the left is beneath
561        // the passing threshold).
562        let mut passed = 0;
563        let mut converged_low = 0;
564        let mut converged_high = 0;
565        let mut runner = TestRunner::deterministic();
566        for _ in 0..256 {
567            let case = input.new_tree(&mut runner).unwrap();
568            let result = runner.run_one(case, |v| {
569                prop_assert!(v < 15);
570                Ok(())
571            });
572
573            match result {
574                Ok(true) => passed += 1,
575                Err(TestError::Fail(_, 15)) => converged_low += 1,
576                Err(TestError::Fail(_, 30)) => converged_high += 1,
577                e => panic!("Unexpected result: {:?}", e),
578            }
579        }
580
581        assert!(passed >= 32 && passed <= 96, "Bad passed count: {}", passed);
582        assert!(
583            converged_low >= 32 && converged_low <= 160,
584            "Bad converged_low count: {}",
585            converged_low
586        );
587        assert!(
588            converged_high >= 32 && converged_high <= 160,
589            "Bad converged_high count: {}",
590            converged_high
591        );
592    }
593
594    #[test]
595    fn test_tuple_union_weighting() {
596        let input = TupleUnion::new((
597            (1, Arc::new(Just(0usize))),
598            (2, Arc::new(Just(1usize))),
599            (1, Arc::new(Just(2usize))),
600        ));
601
602        let mut counts = [0, 0, 0];
603        let mut runner = TestRunner::deterministic();
604        for _ in 0..65536 {
605            counts[input.new_tree(&mut runner).unwrap().current()] += 1;
606        }
607
608        println!("{:?}", counts);
609        assert!(counts[0] > 0);
610        assert!(counts[2] > 0);
611        assert!(counts[1] > counts[0] * 3 / 2);
612        assert!(counts[1] > counts[2] * 3 / 2);
613    }
614
615    #[test]
616    fn test_tuple_union_all_sizes() {
617        let mut runner = TestRunner::deterministic();
618        let r = Arc::new(1i32..10);
619
620        macro_rules! test {
621            ($($part:expr),*) => {{
622                let input = TupleUnion::new((
623                    $((1, $part.clone())),*,
624                    (1, Arc::new(Just(0i32)))
625                ));
626
627                let mut pass = false;
628                for _ in 0..1024 {
629                    if 0 == input.new_tree(&mut runner).unwrap().current() {
630                        pass = true;
631                        break;
632                    }
633                }
634
635                assert!(pass);
636            }}
637        }
638
639        test!(r); // 2
640        test!(r, r); // 3
641        test!(r, r, r); // 4
642        test!(r, r, r, r); // 5
643        test!(r, r, r, r, r); // 6
644        test!(r, r, r, r, r, r); // 7
645        test!(r, r, r, r, r, r, r); // 8
646        test!(r, r, r, r, r, r, r, r); // 9
647        test!(r, r, r, r, r, r, r, r, r); // 10
648    }
649
650    #[test]
651    fn test_tuple_union_sanity() {
652        check_strategy_sanity(
653            TupleUnion::new((
654                (1, Arc::new(0i32..100i32)),
655                (1, Arc::new(200i32..1000i32)),
656                (1, Arc::new(2000i32..3000i32)),
657            )),
658            None,
659        );
660    }
661
662    /// Test that unions work even if local filtering causes errors.
663    #[test]
664    fn test_filter_union_sanity() {
665        let filter_strategy = (0u32..256).prop_filter("!%5", |&v| 0 != v % 5);
666        check_strategy_sanity(
667            Union::new(vec![filter_strategy; 8]),
668            Some(filter_sanity_options()),
669        );
670    }
671
672    /// Test that tuple unions work even if local filtering causes errors.
673    #[test]
674    fn test_filter_tuple_union_sanity() {
675        let filter_strategy = (0u32..256).prop_filter("!%5", |&v| 0 != v % 5);
676        check_strategy_sanity(
677            TupleUnion::new((
678                (1, Arc::new(filter_strategy.clone())),
679                (1, Arc::new(filter_strategy.clone())),
680                (1, Arc::new(filter_strategy.clone())),
681                (1, Arc::new(filter_strategy.clone())),
682            )),
683            Some(filter_sanity_options()),
684        );
685    }
686
687    fn filter_sanity_options() -> CheckStrategySanityOptions {
688        CheckStrategySanityOptions {
689            // Due to internal rejection sampling, `simplify()` can
690            // converge back to what `complicate()` would do.
691            strict_complicate_after_simplify: false,
692            // Make failed filters return errors to test edge cases.
693            error_on_local_rejects: true,
694            ..CheckStrategySanityOptions::default()
695        }
696    }
697}