// petgraph/algo/astar.rs

1use std::collections::hash_map::Entry::{Occupied, Vacant};
2use std::collections::{BinaryHeap, HashMap};
3
4use std::hash::Hash;
5
6use crate::scored::MinScored;
7use crate::visit::{EdgeRef, GraphBase, IntoEdges, Visitable};
8
9use crate::algo::Measure;
10
11/// \[Generic\] A* shortest path algorithm.
12///
13/// Computes the shortest path from `start` to `finish`, including the total path cost.
14///
15/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
16/// given node is the finish node.
17///
18/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
19/// non-negative.
20///
21/// The function `estimate_cost` should return the estimated cost to the finish for a particular
22/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
23/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
24/// must also be non-negative.
25///
26/// The graph should be `Visitable` and implement `IntoEdges`.
27///
28/// # Example
29/// ```
30/// use petgraph::Graph;
31/// use petgraph::algo::astar;
32///
33/// let mut g = Graph::new();
34/// let a = g.add_node((0., 0.));
35/// let b = g.add_node((2., 0.));
36/// let c = g.add_node((1., 1.));
37/// let d = g.add_node((0., 2.));
38/// let e = g.add_node((3., 3.));
39/// let f = g.add_node((4., 2.));
40/// g.extend_with_edges(&[
41///     (a, b, 2),
42///     (a, d, 4),
43///     (b, c, 1),
44///     (b, f, 7),
45///     (c, e, 5),
46///     (e, f, 1),
47///     (d, e, 1),
48/// ]);
49///
50/// // Graph represented with the weight of each edge
51/// // Edges with '*' are part of the optimal path.
52/// //
53/// //     2       1
54/// // a ----- b ----- c
55/// // | 4*    | 7     |
56/// // d       f       | 5
57/// // | 1*    | 1*    |
58/// // \------ e ------/
59///
60/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
61/// assert_eq!(path, Some((6, vec![a, d, e, f])));
62/// ```
63///
64/// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
65/// found.
66pub fn astar<G, F, H, K, IsGoal>(
67    graph: G,
68    start: G::NodeId,
69    mut is_goal: IsGoal,
70    mut edge_cost: F,
71    mut estimate_cost: H,
72) -> Option<(K, Vec<G::NodeId>)>
73where
74    G: IntoEdges + Visitable,
75    IsGoal: FnMut(G::NodeId) -> bool,
76    G::NodeId: Eq + Hash,
77    F: FnMut(G::EdgeRef) -> K,
78    H: FnMut(G::NodeId) -> K,
79    K: Measure + Copy,
80{
81    let mut visit_next = BinaryHeap::new();
82    let mut scores = HashMap::new(); // g-values, cost to reach the node
83    let mut estimate_scores = HashMap::new(); // f-values, cost to reach + estimate cost to goal
84    let mut path_tracker = PathTracker::<G>::new();
85
86    let zero_score = K::default();
87    scores.insert(start, zero_score);
88    visit_next.push(MinScored(estimate_cost(start), start));
89
90    while let Some(MinScored(estimate_score, node)) = visit_next.pop() {
91        if is_goal(node) {
92            let path = path_tracker.reconstruct_path_to(node);
93            let cost = scores[&node];
94            return Some((cost, path));
95        }
96
97        // This lookup can be unwrapped without fear of panic since the node was necessarily scored
98        // before adding it to `visit_next`.
99        let node_score = scores[&node];
100
101        match estimate_scores.entry(node) {
102            Occupied(mut entry) => {
103                // If the node has already been visited with an equal or lower score than now, then
104                // we do not need to re-visit it.
105                if *entry.get() <= estimate_score {
106                    continue;
107                }
108                entry.insert(estimate_score);
109            }
110            Vacant(entry) => {
111                entry.insert(estimate_score);
112            }
113        }
114
115        for edge in graph.edges(node) {
116            let next = edge.target();
117            let next_score = node_score + edge_cost(edge);
118
119            match scores.entry(next) {
120                Occupied(mut entry) => {
121                    // No need to add neighbors that we have already reached through a shorter path
122                    // than now.
123                    if *entry.get() <= next_score {
124                        continue;
125                    }
126                    entry.insert(next_score);
127                }
128                Vacant(entry) => {
129                    entry.insert(next_score);
130                }
131            }
132
133            path_tracker.set_predecessor(next, node);
134            let next_estimate_score = next_score + estimate_cost(next);
135            visit_next.push(MinScored(next_estimate_score, next));
136        }
137    }
138
139    None
140}
141
142struct PathTracker<G>
143where
144    G: GraphBase,
145    G::NodeId: Eq + Hash,
146{
147    came_from: HashMap<G::NodeId, G::NodeId>,
148}
149
150impl<G> PathTracker<G>
151where
152    G: GraphBase,
153    G::NodeId: Eq + Hash,
154{
155    fn new() -> PathTracker<G> {
156        PathTracker {
157            came_from: HashMap::new(),
158        }
159    }
160
161    fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
162        self.came_from.insert(node, previous);
163    }
164
165    fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
166        let mut path = vec![last];
167
168        let mut current = last;
169        while let Some(&previous) = self.came_from.get(&current) {
170            path.push(previous);
171            current = previous;
172        }
173
174        path.reverse();
175
176        path
177    }
178}