petgraph/algo/astar.rs
1use std::collections::hash_map::Entry::{Occupied, Vacant};
2use std::collections::{BinaryHeap, HashMap};
3
4use std::hash::Hash;
5
6use crate::scored::MinScored;
7use crate::visit::{EdgeRef, GraphBase, IntoEdges, Visitable};
8
9use crate::algo::Measure;
10
11/// \[Generic\] A* shortest path algorithm.
12///
13/// Computes the shortest path from `start` to `finish`, including the total path cost.
14///
15/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
16/// given node is the finish node.
17///
18/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
19/// non-negative.
20///
21/// The function `estimate_cost` should return the estimated cost to the finish for a particular
22/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
23/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
24/// must also be non-negative.
25///
26/// The graph should be `Visitable` and implement `IntoEdges`.
27///
28/// # Example
29/// ```
30/// use petgraph::Graph;
31/// use petgraph::algo::astar;
32///
33/// let mut g = Graph::new();
34/// let a = g.add_node((0., 0.));
35/// let b = g.add_node((2., 0.));
36/// let c = g.add_node((1., 1.));
37/// let d = g.add_node((0., 2.));
38/// let e = g.add_node((3., 3.));
39/// let f = g.add_node((4., 2.));
40/// g.extend_with_edges(&[
41/// (a, b, 2),
42/// (a, d, 4),
43/// (b, c, 1),
44/// (b, f, 7),
45/// (c, e, 5),
46/// (e, f, 1),
47/// (d, e, 1),
48/// ]);
49///
50/// // Graph represented with the weight of each edge
51/// // Edges with '*' are part of the optimal path.
52/// //
53/// // 2 1
54/// // a ----- b ----- c
55/// // | 4* | 7 |
56/// // d f | 5
57/// // | 1* | 1* |
58/// // \------ e ------/
59///
60/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
61/// assert_eq!(path, Some((6, vec![a, d, e, f])));
62/// ```
63///
64/// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
65/// found.
66pub fn astar<G, F, H, K, IsGoal>(
67 graph: G,
68 start: G::NodeId,
69 mut is_goal: IsGoal,
70 mut edge_cost: F,
71 mut estimate_cost: H,
72) -> Option<(K, Vec<G::NodeId>)>
73where
74 G: IntoEdges + Visitable,
75 IsGoal: FnMut(G::NodeId) -> bool,
76 G::NodeId: Eq + Hash,
77 F: FnMut(G::EdgeRef) -> K,
78 H: FnMut(G::NodeId) -> K,
79 K: Measure + Copy,
80{
81 let mut visit_next = BinaryHeap::new();
82 let mut scores = HashMap::new(); // g-values, cost to reach the node
83 let mut estimate_scores = HashMap::new(); // f-values, cost to reach + estimate cost to goal
84 let mut path_tracker = PathTracker::<G>::new();
85
86 let zero_score = K::default();
87 scores.insert(start, zero_score);
88 visit_next.push(MinScored(estimate_cost(start), start));
89
90 while let Some(MinScored(estimate_score, node)) = visit_next.pop() {
91 if is_goal(node) {
92 let path = path_tracker.reconstruct_path_to(node);
93 let cost = scores[&node];
94 return Some((cost, path));
95 }
96
97 // This lookup can be unwrapped without fear of panic since the node was necessarily scored
98 // before adding it to `visit_next`.
99 let node_score = scores[&node];
100
101 match estimate_scores.entry(node) {
102 Occupied(mut entry) => {
103 // If the node has already been visited with an equal or lower score than now, then
104 // we do not need to re-visit it.
105 if *entry.get() <= estimate_score {
106 continue;
107 }
108 entry.insert(estimate_score);
109 }
110 Vacant(entry) => {
111 entry.insert(estimate_score);
112 }
113 }
114
115 for edge in graph.edges(node) {
116 let next = edge.target();
117 let next_score = node_score + edge_cost(edge);
118
119 match scores.entry(next) {
120 Occupied(mut entry) => {
121 // No need to add neighbors that we have already reached through a shorter path
122 // than now.
123 if *entry.get() <= next_score {
124 continue;
125 }
126 entry.insert(next_score);
127 }
128 Vacant(entry) => {
129 entry.insert(next_score);
130 }
131 }
132
133 path_tracker.set_predecessor(next, node);
134 let next_estimate_score = next_score + estimate_cost(next);
135 visit_next.push(MinScored(next_estimate_score, next));
136 }
137 }
138
139 None
140}
141
142struct PathTracker<G>
143where
144 G: GraphBase,
145 G::NodeId: Eq + Hash,
146{
147 came_from: HashMap<G::NodeId, G::NodeId>,
148}
149
150impl<G> PathTracker<G>
151where
152 G: GraphBase,
153 G::NodeId: Eq + Hash,
154{
155 fn new() -> PathTracker<G> {
156 PathTracker {
157 came_from: HashMap::new(),
158 }
159 }
160
161 fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
162 self.came_from.insert(node, previous);
163 }
164
165 fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
166 let mut path = vec![last];
167
168 let mut current = last;
169 while let Some(&previous) = self.came_from.get(¤t) {
170 path.push(previous);
171 current = previous;
172 }
173
174 path.reverse();
175
176 path
177 }
178}