1use crate::std_facade::{fmt, Arc, Vec};
11use core::cmp::{max, min};
12use core::u32;
13
14#[cfg(not(feature = "std"))]
15use num_traits::float::FloatCore;
16
17use crate::num::sample_uniform;
18use crate::strategy::{lazy::LazyValueTree, traits::*};
19use crate::test_runner::*;
20
/// One weighted alternative, `(weight, strategy)`, with the strategy owned
/// directly. This is the form taken by [`Union::new_weighted`].
pub type W<T> = (u32, T);

/// One weighted alternative, `(weight, strategy)`, with the strategy shared
/// behind an `Arc`. `Union` stores its options in this form, and each
/// element of a `TupleUnion` tuple has this shape.
pub type WA<T> = (u32, Arc<T>);
28
29#[derive(Clone, Debug)]
33#[must_use = "strategies do nothing unless used"]
34pub struct Union<T: Strategy> {
35 options: Vec<WA<T>>,
38}
39
40impl<T: Strategy> Union<T> {
41 pub fn new(options: impl IntoIterator<Item = T>) -> Self {
52 let options: Vec<WA<T>> =
53 options.into_iter().map(|v| (1, Arc::new(v))).collect();
54 assert!(!options.is_empty());
55 Self { options }
56 }
57
58 pub(crate) fn try_new<E>(
59 it: impl Iterator<Item = Result<T, E>>,
60 ) -> Result<Self, E> {
61 let options: Vec<WA<T>> = it
62 .map(|r| r.map(|v| (1, Arc::new(v))))
63 .collect::<Result<_, _>>()?;
64
65 assert!(!options.is_empty());
66 Ok(Self { options })
67 }
68
69 pub fn new_weighted(options: Vec<W<T>>) -> Self {
82 assert!(!options.is_empty());
83 assert!(
84 !options.iter().any(|&(w, _)| 0 == w),
85 "Union option has a weight of 0"
86 );
87 assert!(
88 options.iter().map(|&(w, _)| u64::from(w)).sum::<u64>()
89 <= u64::from(u32::MAX),
90 "Union weights overflow u32"
91 );
92 let options =
93 options.into_iter().map(|(w, v)| (w, Arc::new(v))).collect();
94 Self { options }
95 }
96
97 pub fn or(mut self, other: T) -> Self {
99 self.options.push((1, Arc::new(other)));
100 self
101 }
102}
103
104fn pick_weighted<I: Iterator<Item = u32>>(
105 runner: &mut TestRunner,
106 weights1: I,
107 weights2: I,
108) -> usize {
109 let sum = weights1.map(u64::from).sum();
110 let weighted_pick = sample_uniform(runner, 0, sum);
111 weights2
112 .scan(0u64, |state, w| {
113 *state += u64::from(w);
114 Some(*state)
115 })
116 .filter(|&v| v <= weighted_pick)
117 .count()
118}
119
120impl<T: Strategy> Strategy for Union<T> {
121 type Tree = UnionValueTree<T>;
122 type Value = T::Value;
123
124 fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
125 fn extract_weight<V>(&(w, _): &WA<V>) -> u32 {
126 w
127 }
128
129 let pick = pick_weighted(
130 runner,
131 self.options.iter().map(extract_weight::<T>),
132 self.options.iter().map(extract_weight::<T>),
133 );
134
135 let mut options = Vec::with_capacity(pick);
136
137 for option in &self.options[0..pick] {
139 options.push(LazyValueTree::new(Arc::clone(&option.1), runner));
140 }
141
142 options.push(LazyValueTree::new_initialized(
146 self.options[pick].1.new_tree(runner)?,
147 ));
148
149 Ok(UnionValueTree {
150 options,
151 pick,
152 min_pick: 0,
153 prev_pick: None,
154 })
155 }
156}
157
// Slot accessor used by `lazy_union_value_tree_body!` for `UnionValueTree`.
//
// `access_vec!([] name = self, ix, { body })` binds `name` to
// `&self.options[ix]`. Invoking it with `[mut]` binds `&mut` instead. It
// then evaluates `body`. Indexing panics if `ix` is out of range.
macro_rules! access_vec {
    ([$($muta:tt)*] $dst:ident = $this:expr, $ix:expr, $body:block) => {{
        let $dst = &$($muta)* $this.options[$ix];
        $body
    }}
}
164
165pub struct UnionValueTree<T: Strategy> {
167 options: Vec<LazyValueTree<T>>,
168 pick: usize,
172 min_pick: usize,
173 prev_pick: Option<usize>,
174}
175
// Shared `ValueTree` method bodies for `UnionValueTree` and every
// `TupleUnionValueTree` arity.
//
// `$typ` is the produced `Value` type. `$access` names a slot-accessor macro
// (`access_vec` or one of the `access_tupleN` macros), invoked as
// `$access!([] name = self, index, { body })` or with `[mut]`. It binds `name`
// to the option slot at `index` and evaluates `body`.
//
// Shrinking state used below:
// - `pick`: the option whose tree currently supplies the value;
// - `min_pick`: the lowest index `simplify` may still move `pick` to;
// - `prev_pick`: set by `simplify` to the pick it started from when it goes
//   looking for a lower option, so that `complicate` can move back.
macro_rules! lazy_union_value_tree_body {
    ($typ:ty, $access:ident) => {
        type Value = $typ;

        // Delegates to the tree at `pick`. That tree must be initialized;
        // this panics otherwise.
        fn current(&self) -> Self::Value {
            $access!([] opt = self, self.pick, {
                opt.as_inner().unwrap_or_else(||
                    panic!(
                        "value tree at self.pick = {} must be initialized",
                        self.pick,
                    )
                ).current()
            })
        }

        // Simplifies within the current option first. Only when that is
        // exhausted does it switch to a lower-indexed option.
        fn simplify(&mut self) -> bool {
            let orig_pick = self.pick;
            if $access!([mut] opt = self, orig_pick, {
                opt.as_inner_mut().unwrap_or_else(||
                    panic!(
                        "value tree at self.pick = {} must be initialized",
                        orig_pick,
                    )
                ).simplify()
            }) {
                // The step happened inside the current option, so a later
                // `complicate` must go to that option's tree rather than
                // switch options.
                self.prev_pick = None;
                return true;
            }

            assert!(
                self.pick >= self.min_pick,
                "self.pick = {} should never go below self.min_pick = {}",
                self.pick,
                self.min_pick,
            );
            if self.pick == self.min_pick {
                // No lower option may be tried.
                return false;
            }

            // Remember where we came from so `complicate` can return here.
            self.prev_pick = Some(self.pick);

            // Walk downwards, lazily initializing each candidate. A slot that
            // `maybe_init` leaves uninitialized (presumably because
            // generation failed; see `LazyValueTree`) is skipped.
            let mut next_pick = self.pick;
            while next_pick > self.min_pick {
                next_pick -= 1;
                let initialized = $access!([mut] opt = self, next_pick, {
                    opt.maybe_init();
                    opt.is_initialized()
                });
                if initialized {
                    self.pick = next_pick;
                    return true;
                }
            }

            // No usable lower option was found; `pick` is unchanged.
            false
        }

        // If the last `simplify` switched options, switch back and raise
        // `min_pick` so that option range is not retried. Otherwise forward
        // to the current option's tree.
        fn complicate(&mut self) -> bool {
            if let Some(pick) = self.prev_pick {
                self.pick = pick;
                self.min_pick = pick;
                self.prev_pick = None;
                true
            } else {
                let pick = self.pick;
                $access!([mut] opt = self, pick, {
                    opt.as_inner_mut().unwrap_or_else(||
                        panic!(
                            "value tree at self.pick = {} must be initialized",
                            pick,
                        )
                    ).complicate()
                })
            }
        }
    }
}
257
// The shrinking logic is shared with `TupleUnionValueTree`. The options live
// in a `Vec`, so `access_vec` indexes them directly.
impl<T: Strategy> ValueTree for UnionValueTree<T> {
    lazy_union_value_tree_body!(T::Value, access_vec);
}
261
262impl<T: Strategy> Clone for UnionValueTree<T>
263where
264 T::Tree: Clone,
265{
266 fn clone(&self) -> Self {
267 Self {
268 options: self.options.clone(),
269 pick: self.pick,
270 min_pick: self.min_pick,
271 prev_pick: self.prev_pick,
272 }
273 }
274}
275
276impl<T: Strategy> fmt::Debug for UnionValueTree<T>
277where
278 T::Tree: fmt::Debug,
279{
280 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
281 f.debug_struct("UnionValueTree")
282 .field("options", &self.options)
283 .field("pick", &self.pick)
284 .field("min_pick", &self.min_pick)
285 .field("prev_pick", &self.prev_pick)
286 .finish()
287 }
288}
289
// Defines a slot-accessor macro `$name` for a `TupleUnionValueTree` whose
// `options` tuple holds slot 0 directly and every later slot `$n` as an
// `Option`.
//
// `$b` must be given a literal `$` token (see the invocations below). This
// lets the outer macro write `$`-metavariables into the inner macro it
// generates.
//
// The generated macro is used as `$name!([] dst = this, ix, { body })` or with
// `[mut]`. It matches `ix` against the slot numbers, binds `dst` to a
// reference (`ref` / `ref mut`) to that slot's tree and evaluates `body`. It
// panics if the slot is `None` or if `ix` is not a known slot number.
macro_rules! def_access_tuple {
    ($b:tt $name:ident, $($n:tt)*) => {
        macro_rules! $name {
            ([$b($b muta:tt)*] $b dst:ident = $b this:expr,
             $b ix:expr, $b body:block) => {
                match $b ix {
                    // Slot 0 is stored without an `Option` wrapper.
                    0 => {
                        let $b dst = &$b($b muta)* $b this.options.0;
                        $b body
                    },
                    $(
                        $n => {
                            if let Some(ref $b($b muta)* $b dst) =
                                $b this.options.$n
                            {
                                $b body
                            } else {
                                panic!("TupleUnion tried to access \
                                        uninitialised slot {}", $n)
                            }
                        },
                    )*
                    _ => panic!("TupleUnion tried to access out-of-range \
                                 slot {}", $b ix),
                }
            }
        }
    }
}
319
// `access_tupleN` handles option tuples with N slots (0 through N-1). The
// suffix `A` stands for 10.
def_access_tuple!($ access_tuple2, 1);
def_access_tuple!($ access_tuple3, 1 2);
def_access_tuple!($ access_tuple4, 1 2 3);
def_access_tuple!($ access_tuple5, 1 2 3 4);
def_access_tuple!($ access_tuple6, 1 2 3 4 5);
def_access_tuple!($ access_tuple7, 1 2 3 4 5 6);
def_access_tuple!($ access_tuple8, 1 2 3 4 5 6 7);
def_access_tuple!($ access_tuple9, 1 2 3 4 5 6 7 8);
def_access_tuple!($ access_tupleA, 1 2 3 4 5 6 7 8 9);
329
/// Like [`Union`], but the alternatives are held in a tuple rather than a
/// `Vec`. Each alternative may therefore be a different strategy type, as
/// long as they all produce the first alternative's `Value`.
///
/// Each tuple element is a `WA<_>` pair, `(weight, Arc<strategy>)`.
/// `Strategy` is implemented for tuples of 2 through 10 alternatives.
#[derive(Clone, Copy, Debug)]
#[must_use = "strategies do nothing unless used"]
pub struct TupleUnion<T>(T);

impl<T> TupleUnion<T> {
    /// Wraps `tuple` as a union strategy.
    ///
    /// The tuple is stored as given; no validation is done here.
    pub fn new(tuple: T) -> Self {
        Self(tuple)
    }
}
357
// Implements `Strategy` for a `TupleUnion` over `(WA<A>, WA<B>, ...)`.
//
// Invoked as `tuple_union!(B 1 C 2 ...)`. Each `$gen $ix` pair gives the
// strategy type and tuple index of one alternative after the first; index 0,
// with type `A`, is always present. Every alternative must produce
// `A::Value`.
macro_rules! tuple_union {
    ($($gen:ident $ix:tt)*) => {
        impl<A : Strategy, $($gen: Strategy<Value = A::Value>),*>
        Strategy for TupleUnion<(WA<A>, $(WA<$gen>),*)> {
            type Tree = TupleUnionValueTree<
                (LazyValueTree<A>, $(Option<LazyValueTree<$gen>>),*)>;
            type Value = A::Value;

            fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
                let weights = [((self.0).0).0, $(((self.0).$ix).0),*];
                let pick = pick_weighted(runner, weights.iter().cloned(),
                                         weights.iter().cloned());

                // Tuple elements are evaluated left to right, so the runner
                // is used in option order, as in `Union::new_tree`:
                // - the option at `pick` is generated now (if this fails,
                //   `new_tree` fails);
                // - options below `pick` get a deferred `LazyValueTree`;
                // - options above `pick` are stored as `None`, since
                //   `simplify` only ever moves to lower indices.
                // Slot 0 has no `Option` wrapper because it is always at or
                // below `pick`.
                Ok(TupleUnionValueTree {
                    options: (
                        if 0 == pick {
                            LazyValueTree::new_initialized(
                                ((self.0).0).1.new_tree(runner)?)
                        } else {
                            LazyValueTree::new(
                                Arc::clone(&((self.0).0).1), runner)
                        },
                        $(
                        if $ix == pick {
                            Some(LazyValueTree::new_initialized(
                                 ((self.0).$ix).1.new_tree(runner)?))
                        } else if $ix < pick {
                            Some(LazyValueTree::new(
                                 Arc::clone(&((self.0).$ix).1), runner))
                        } else {
                            None
                        }),*),
                    pick: pick,
                    min_pick: 0,
                    prev_pick: None,
                })
            }
        }
    }
}
398
// `Strategy` impls for `TupleUnion` with 2 through 10 alternatives.
tuple_union!(B 1);
tuple_union!(B 1 C 2);
tuple_union!(B 1 C 2 D 3);
tuple_union!(B 1 C 2 D 3 E 4);
tuple_union!(B 1 C 2 D 3 E 4 F 5);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7 I 8);
tuple_union!(B 1 C 2 D 3 E 4 F 5 G 6 H 7 I 8 J 9);
408
/// The `ValueTree` produced by [`TupleUnion`].
///
/// `T` is the tuple of per-option trees. Slot 0 is a bare `LazyValueTree`,
/// and every later slot is an `Option`, which is `None` for options above the
/// original pick.
#[derive(Debug, Copy, Clone)]
pub struct TupleUnionValueTree<T> {
    /// Per-option trees, indexed by tuple position.
    options: T,
    /// Index of the option whose tree currently supplies the value.
    pick: usize,
    /// Lowest index that `simplify` may still move `pick` down to.
    min_pick: usize,
    /// The pick before the most recent switch of option, while `complicate`
    /// can still undo that switch.
    prev_pick: Option<usize>,
}
417
// Implements `ValueTree` for the `TupleUnionValueTree` whose options tuple
// matches the `Tree` type built by `tuple_union!`: slot 0 is
// `LazyValueTree<A>` and the later slots are `Option<LazyValueTree<_>>`.
// `$access` must be the `access_tupleN` macro for that number of slots.
macro_rules! value_tree_tuple {
    ($access:ident, $($gen:ident)*) => {
        impl<A : Strategy, $($gen: Strategy<Value = A::Value>),*> ValueTree
        for TupleUnionValueTree<
            (LazyValueTree<A>, $(Option<LazyValueTree<$gen>>),*)
        > {
            lazy_union_value_tree_body!(A::Value, $access);
        }
    }
}
428
// `ValueTree` impls for 2 through 10 options. Each accessor's slot count
// matches the number of type parameters: `A` plus those listed.
value_tree_tuple!(access_tuple2, B);
value_tree_tuple!(access_tuple3, B C);
value_tree_tuple!(access_tuple4, B C D);
value_tree_tuple!(access_tuple5, B C D E);
value_tree_tuple!(access_tuple6, B C D E F);
value_tree_tuple!(access_tuple7, B C D E F G);
value_tree_tuple!(access_tuple8, B C D E F G H);
value_tree_tuple!(access_tuple9, B C D E F G H I);
value_tree_tuple!(access_tupleA, B C D E F G H I J);
438
/// Total weight used by [`float_to_weight`]. The returned pair always sums to
/// this value (2^31).
const WEIGHT_BASE: u32 = 1 << 31;

/// Converts the probability `f` of one outcome into a pair of integer
/// weights `(pos, neg)` with `pos + neg == WEIGHT_BASE`.
///
/// `pos` is `f * WEIGHT_BASE` rounded to the nearest integer, then clamped
/// to `1..=WEIGHT_BASE - 1`, so neither outcome ever gets weight 0.
///
/// # Panics
///
/// Panics unless `0.0 < f < 1.0`. NaN is rejected as well.
pub fn float_to_weight(f: f64) -> (u32, u32) {
    assert!(f > 0.0 && f < 1.0, "Invalid probability: {}", f);

    let scaled = (f * f64::from(WEIGHT_BASE)).round() as u32;
    // `f` close enough to 0 or 1 would otherwise round to a zero weight on
    // one side.
    let capped = min(scaled, WEIGHT_BASE - 1);
    let pos = max(capped, 1);

    (pos, WEIGHT_BASE - pos)
}
467
#[cfg(test)]
mod test {
    use super::*;
    use crate::strategy::just::Just;

    // Every failing case should shrink either to 15, the smallest failing
    // value in 10..20, or to 30, the smallest value in 30..40.
    #[cfg(feature = "std")]
    #[test]
    fn test_union() {
        let input = (10u32..20u32).prop_union(30u32..40u32);
        let mut passed = 0;
        let mut converged_low = 0;
        let mut converged_high = 0;
        let mut runner = TestRunner::deterministic();
        for _ in 0..256 {
            let case = input.new_tree(&mut runner).unwrap();
            let result = runner.run_one(case, |v| {
                prop_assert!(v < 15);
                Ok(())
            });

            match result {
                Ok(true) => passed += 1,
                Err(TestError::Fail(_, 15)) => converged_low += 1,
                Err(TestError::Fail(_, 30)) => converged_high += 1,
                e => panic!("Unexpected result: {:?}", e),
            }
        }

        assert!(passed >= 32 && passed <= 96, "Bad passed count: {}", passed);
        assert!(
            converged_low >= 32 && converged_low <= 160,
            "Bad converged_low count: {}",
            converged_low
        );
        assert!(
            converged_high >= 32 && converged_high <= 160,
            "Bad converged_high count: {}",
            converged_high
        );
    }

    // The option with twice the weight should be picked clearly more often.
    #[test]
    fn test_union_weighted() {
        let input = Union::new_weighted(vec![
            (1, Just(0usize)),
            (2, Just(1usize)),
            (1, Just(2usize)),
        ]);

        let mut counts = [0, 0, 0];
        let mut runner = TestRunner::deterministic();
        for _ in 0..65536 {
            counts[input.new_tree(&mut runner).unwrap().current()] += 1;
        }

        println!("{:?}", counts);
        assert!(counts[0] > 0);
        assert!(counts[2] > 0);
        assert!(counts[1] > counts[0] * 3 / 2);
        assert!(counts[1] > counts[2] * 3 / 2);
    }

    #[test]
    fn test_union_sanity() {
        check_strategy_sanity(
            Union::new_weighted(vec![
                (1, 0i32..100),
                (2, 200i32..300),
                (1, 400i32..500),
            ]),
            None,
        );
    }

    // Same expectations as `test_union`, but through `TupleUnion`.
    #[cfg(feature = "std")]
    #[test]
    fn test_tuple_union() {
        let input = TupleUnion::new((
            (1, Arc::new(10u32..20u32)),
            (1, Arc::new(30u32..40u32)),
        ));
        let mut passed = 0;
        let mut converged_low = 0;
        let mut converged_high = 0;
        let mut runner = TestRunner::deterministic();
        for _ in 0..256 {
            let case = input.new_tree(&mut runner).unwrap();
            let result = runner.run_one(case, |v| {
                prop_assert!(v < 15);
                Ok(())
            });

            match result {
                Ok(true) => passed += 1,
                Err(TestError::Fail(_, 15)) => converged_low += 1,
                Err(TestError::Fail(_, 30)) => converged_high += 1,
                e => panic!("Unexpected result: {:?}", e),
            }
        }

        assert!(passed >= 32 && passed <= 96, "Bad passed count: {}", passed);
        assert!(
            converged_low >= 32 && converged_low <= 160,
            "Bad converged_low count: {}",
            converged_low
        );
        assert!(
            converged_high >= 32 && converged_high <= 160,
            "Bad converged_high count: {}",
            converged_high
        );
    }

    #[test]
    fn test_tuple_union_weighting() {
        let input = TupleUnion::new((
            (1, Arc::new(Just(0usize))),
            (2, Arc::new(Just(1usize))),
            (1, Arc::new(Just(2usize))),
        ));

        let mut counts = [0, 0, 0];
        let mut runner = TestRunner::deterministic();
        for _ in 0..65536 {
            counts[input.new_tree(&mut runner).unwrap().current()] += 1;
        }

        println!("{:?}", counts);
        assert!(counts[0] > 0);
        assert!(counts[2] > 0);
        assert!(counts[1] > counts[0] * 3 / 2);
        assert!(counts[1] > counts[2] * 3 / 2);
    }

    // For every supported arity (2 through 10), the last alternative must be
    // reachable.
    #[test]
    fn test_tuple_union_all_sizes() {
        let mut runner = TestRunner::deterministic();
        let r = Arc::new(1i32..10);

        macro_rules! test {
            ($($part:expr),*) => {{
                let input = TupleUnion::new((
                    $((1, $part.clone())),*,
                    (1, Arc::new(Just(0i32)))
                ));

                let mut pass = false;
                for _ in 0..1024 {
                    if 0 == input.new_tree(&mut runner).unwrap().current() {
                        pass = true;
                        break;
                    }
                }

                assert!(pass);
            }}
        }

        test!(r);
        test!(r, r);
        test!(r, r, r);
        test!(r, r, r, r);
        test!(r, r, r, r, r);
        test!(r, r, r, r, r, r);
        test!(r, r, r, r, r, r, r);
        test!(r, r, r, r, r, r, r, r);
        test!(r, r, r, r, r, r, r, r, r);
    }

    #[test]
    fn test_tuple_union_sanity() {
        check_strategy_sanity(
            TupleUnion::new((
                (1, Arc::new(0i32..100i32)),
                (1, Arc::new(200i32..1000i32)),
                (1, Arc::new(2000i32..3000i32)),
            )),
            None,
        );
    }

    #[test]
    fn test_filter_union_sanity() {
        let filter_strategy = (0u32..256).prop_filter("!%5", |&v| 0 != v % 5);
        check_strategy_sanity(
            Union::new(vec![filter_strategy; 8]),
            Some(filter_sanity_options()),
        );
    }

    #[test]
    fn test_filter_tuple_union_sanity() {
        let filter_strategy = (0u32..256).prop_filter("!%5", |&v| 0 != v % 5);
        check_strategy_sanity(
            TupleUnion::new((
                (1, Arc::new(filter_strategy.clone())),
                (1, Arc::new(filter_strategy.clone())),
                (1, Arc::new(filter_strategy.clone())),
                (1, Arc::new(filter_strategy.clone())),
            )),
            Some(filter_sanity_options()),
        );
    }

    fn filter_sanity_options() -> CheckStrategySanityOptions {
        CheckStrategySanityOptions {
            strict_complicate_after_simplify: false,
            error_on_local_rejects: true,
            ..CheckStrategySanityOptions::default()
        }
    }

    // `float_to_weight` must split `WEIGHT_BASE` exactly, and it must clamp
    // so that neither side ever gets a zero weight.
    #[test]
    fn test_float_to_weight() {
        assert_eq!(float_to_weight(0.5), (WEIGHT_BASE / 2, WEIGHT_BASE / 2));
        assert_eq!(float_to_weight(1e-20), (1, WEIGHT_BASE - 1));
        assert_eq!(float_to_weight(1.0 - 1e-16), (WEIGHT_BASE - 1, 1));
    }

    #[test]
    #[should_panic(expected = "Invalid probability")]
    fn test_float_to_weight_rejects_certainty() {
        float_to_weight(1.0);
    }
}