// proptest/strategy/recursive.rs

//-
// Copyright 2017 Jason Lingle
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::std_facade::{fmt, Arc, Box, Vec};

use crate::strategy::traits::*;
use crate::strategy::unions::float_to_weight;
use crate::test_runner::*;

16/// Return type from `Strategy::prop_recursive()`.
17#[must_use = "strategies do nothing unless used"]
18pub struct Recursive<T, F> {
19    base: BoxedStrategy<T>,
20    recurse: Arc<F>,
21    depth: u32,
22    desired_size: u32,
23    expected_branch_size: u32,
24}
25
26impl<T: fmt::Debug, F> fmt::Debug for Recursive<T, F> {
27    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28        f.debug_struct("Recursive")
29            .field("base", &self.base)
30            .field("recurse", &"<function>")
31            .field("depth", &self.depth)
32            .field("desired_size", &self.desired_size)
33            .field("expected_branch_size", &self.expected_branch_size)
34            .finish()
35    }
36}
37
38impl<T, F> Clone for Recursive<T, F> {
39    fn clone(&self) -> Self {
40        Recursive {
41            base: self.base.clone(),
42            recurse: Arc::clone(&self.recurse),
43            depth: self.depth,
44            desired_size: self.desired_size,
45            expected_branch_size: self.expected_branch_size,
46        }
47    }
48}
49
50impl<
51        T: fmt::Debug + 'static,
52        R: Strategy<Value = T> + 'static,
53        F: Fn(BoxedStrategy<T>) -> R,
54    > Recursive<T, F>
55{
56    pub(super) fn new(
57        base: impl Strategy<Value = T> + 'static,
58        depth: u32,
59        desired_size: u32,
60        expected_branch_size: u32,
61        recurse: F,
62    ) -> Self {
63        Self {
64            base: base.boxed(),
65            recurse: Arc::new(recurse),
66            depth,
67            desired_size,
68            expected_branch_size,
69        }
70    }
71}
72
73impl<
74        T: fmt::Debug + 'static,
75        R: Strategy<Value = T> + 'static,
76        F: Fn(BoxedStrategy<T>) -> R,
77    > Strategy for Recursive<T, F>
78{
79    type Tree = Box<dyn ValueTree<Value = T>>;
80    type Value = T;
81
82    fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
83        // Since the generator is stateless, we can't implement any "absolutely
84        // X many items" rule. We _can_, however, with extremely high
85        // probability, obtain a value near what we want by using decaying
86        // probabilities of branching as we go down the tree.
87        //
88        // We are given a target size S and a branch size K (branch size =
89        // expected number of items immediately below each branch). We select
90        // some probability P for each level.
91        //
92        // A single level l is thus expected to hold PlK branches. Each of
93        // those will have P(l+1)K child branches of their own, so there are
94        // PlP(l+1)K² second-level branches. The total branches in the tree is
95        // thus (Σ PlK^l) for l from 0 to infinity. Each level is expected to
96        // hold K items, so the total number of items is simply K times the
97        // number of branches, or (K Σ PlK^l). So we want to find a P sequence
98        // such that (lim (K Σ PlK^l) = S), or more simply,
99        // (lim Σ PlK^l = S/K).
100        //
101        // Let Q be a second probability sequence such that Pl = Ql/K^l. This
102        // changes the formulation to (lim Σ Ql = S/K). The series Σ0.5^(l+1)
103        // converges on 1.0, so we can let Ql = S/K * 0.5^(l+1), and so
104        // Pl = S/K^(l+1) * 0.5^(l+1) = S / (2K) ^ (l+1)
105        //
106        // We don't actually have infinite levels here since we _can_ easily
107        // cap to a fixed max depth, so this will be a minor underestimate. We
108        // also clamp all probabilities to 0.9 to ensure that we can't end up
109        // with levels which are always pure branches, which further
110        // underestimates size.
111
112        let mut branch_probabilities = Vec::new();
113        let mut k2 = u64::from(self.expected_branch_size) * 2;
114        for _ in 0..self.depth {
115            branch_probabilities.push(f64::from(self.desired_size) / k2 as f64);
116            k2 = k2.saturating_mul(u64::from(self.expected_branch_size) * 2);
117        }
118
119        let mut strat = self.base.clone();
120        while let Some(branch_probability) = branch_probabilities.pop() {
121            let recursed = (self.recurse)(strat.clone());
122            let recursive_choice = recursed.boxed();
123            let non_recursive_choice = strat;
124            // Clamp the maximum branch probability to 0.9 to ensure we can
125            // generate non-recursive cases reasonably often.
126            let branch_probability = branch_probability.min(0.9);
127            let (weight_branch, weight_leaf) =
128                float_to_weight(branch_probability);
129            let branch = prop_oneof![
130                weight_leaf => non_recursive_choice,
131                weight_branch => recursive_choice,
132            ];
133            strat = branch.boxed();
134        }
135
136        strat.new_tree(runner)
137    }
138}
139
#[cfg(test)]
mod test {
    use std::cmp::max;

    use super::*;
    use crate::strategy::just::Just;

    /// Minimal recursive data type: a leaf, or a branch owning children.
    #[derive(Clone, Debug, PartialEq)]
    enum Tree {
        Leaf,
        Branch(Vec<Tree>),
    }

    impl Tree {
        /// Returns `(depth, count)`: `depth` is the number of branch levels
        /// on the deepest path (0 for a lone leaf) and `count` is the total
        /// number of nodes, leaves and branches alike.
        fn stats(&self) -> (u32, u32) {
            match self {
                Tree::Leaf => (0, 1),
                Tree::Branch(children) => {
                    let (deepest, total) = children
                        .iter()
                        .map(Tree::stats)
                        .fold((0, 0), |(d, c), (child_d, child_c)| {
                            (max(d, child_d), c + child_c)
                        });
                    (deepest + 1, total + 1)
                }
            }
        }
    }

    /// Strategy shared by both tests: at most 4 recursion layers, a desired
    /// size of 64 and an expected branch size of 16, with every branch
    /// holding between 8 and 15 children.
    fn tree_strategy() -> impl Strategy<Value = Tree> {
        Just(Tree::Leaf).prop_recursive(4, 64, 16, |element| {
            crate::collection::vec(element, 8..16).prop_map(Tree::Branch)
        })
    }

    #[test]
    fn test_recursive() {
        let strat = tree_strategy();
        let mut runner = TestRunner::deterministic();

        let mut deepest = 0;
        let mut largest = 0;
        for _ in 0..65536 {
            let generated = strat.new_tree(&mut runner).unwrap().current();
            let (depth, count) = generated.stats();
            assert!(depth <= 4, "Got depth {}", depth);
            assert!(count <= 128, "Got count {}", count);
            deepest = max(deepest, depth);
            largest = max(largest, count);
        }

        // The size limits must not be so tight that recursion never happens.
        assert!(deepest >= 3, "Only got max depth {}", deepest);
        assert!(largest > 48, "Only got max count {}", largest);
    }

    #[test]
    fn simplifies_to_non_recursive() {
        let strat = tree_strategy();
        let mut runner = TestRunner::deterministic();

        for _ in 0..256 {
            let mut value = strat.new_tree(&mut runner).unwrap();
            // Shrink until no further simplification is possible; the fully
            // minimised value must be the base (non-recursive) case.
            while value.simplify() {}

            assert_eq!(Tree::Leaf, value.current());
        }
    }
}