petgraph/algo/astar.rs

1use alloc::{collections::BinaryHeap, vec, vec::Vec};
2use core::hash::Hash;
3
4use hashbrown::hash_map::{
5    Entry::{Occupied, Vacant},
6    HashMap,
7};
8
9use crate::algo::Measure;
10use crate::scored::MinScored;
11use crate::visit::{EdgeRef, GraphBase, IntoEdges, Visitable};
12
13/// \[Generic\] A* shortest path algorithm.
14///
15/// Computes the shortest path from `start` to `finish`, including the total path cost.
16///
17/// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
18/// given node is the finish node.
19///
20/// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
21/// non-negative.
22///
23/// The function `estimate_cost` should return the estimated cost to the finish for a particular
24/// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
25/// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
26/// must also be non-negative.
27///
28/// # Arguments
29/// * `graph`: weighted graph.
30/// * `start`: the start node.
31/// * `is_goal`: the callback defines the goal node.
32/// * `edge_cost`: closure that returns cost of a particular edge.
33/// * `estimate_cost`: closure that returns the estimated cost to the finish for particular node.
34///
35/// # Returns
36/// * `Some(K, Vec<G::NodeId>)` - the total cost and path from start to finish, if one was found.
37/// * `None` - if such a path was not found.
38///
39/// # Complexity
40/// The time complexity largely depends on the heuristic used. Feel free to contribute and provide the exact time complexity :)
41///
42/// With a trivial heuristic, the algorithm will behave like [`fn@crate::algo::dijkstra`].
43///
44/// # Example
45/// ```
46/// use petgraph::Graph;
47/// use petgraph::algo::astar;
48///
49/// let mut g = Graph::new();
50/// let a = g.add_node((0., 0.));
51/// let b = g.add_node((2., 0.));
52/// let c = g.add_node((1., 1.));
53/// let d = g.add_node((0., 2.));
54/// let e = g.add_node((3., 3.));
55/// let f = g.add_node((4., 2.));
56/// g.extend_with_edges(&[
57///     (a, b, 2),
58///     (a, d, 4),
59///     (b, c, 1),
60///     (b, f, 7),
61///     (c, e, 5),
62///     (e, f, 1),
63///     (d, e, 1),
64/// ]);
65///
66/// // Graph represented with the weight of each edge
67/// // Edges with '*' are part of the optimal path.
68/// //
69/// //     2       1
70/// // a ----- b ----- c
71/// // | 4*    | 7     |
72/// // d       f       | 5
73/// // | 1*    | 1*    |
74/// // \------ e ------/
75///
76/// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
77/// assert_eq!(path, Some((6, vec![a, d, e, f])));
78/// ```
79pub fn astar<G, F, H, K, IsGoal>(
80    graph: G,
81    start: G::NodeId,
82    mut is_goal: IsGoal,
83    mut edge_cost: F,
84    mut estimate_cost: H,
85) -> Option<(K, Vec<G::NodeId>)>
86where
87    G: IntoEdges + Visitable,
88    IsGoal: FnMut(G::NodeId) -> bool,
89    G::NodeId: Eq + Hash,
90    F: FnMut(G::EdgeRef) -> K,
91    H: FnMut(G::NodeId) -> K,
92    K: Measure + Copy,
93{
94    let mut visit_next = BinaryHeap::new();
95    let mut scores = HashMap::new(); // g-values, cost to reach the node
96    let mut estimate_scores = HashMap::new(); // f-values, cost to reach + estimate cost to goal
97    let mut path_tracker = PathTracker::<G>::new();
98
99    let zero_score = K::default();
100    scores.insert(start, zero_score);
101    visit_next.push(MinScored(estimate_cost(start), start));
102
103    while let Some(MinScored(estimate_score, node)) = visit_next.pop() {
104        if is_goal(node) {
105            let path = path_tracker.reconstruct_path_to(node);
106            let cost = scores[&node];
107            return Some((cost, path));
108        }
109
110        // This lookup can be unwrapped without fear of panic since the node was necessarily scored
111        // before adding it to `visit_next`.
112        let node_score = scores[&node];
113
114        match estimate_scores.entry(node) {
115            Occupied(mut entry) => {
116                // If the node has already been visited with an equal or lower score than now, then
117                // we do not need to re-visit it.
118                if *entry.get() <= estimate_score {
119                    continue;
120                }
121                entry.insert(estimate_score);
122            }
123            Vacant(entry) => {
124                entry.insert(estimate_score);
125            }
126        }
127
128        for edge in graph.edges(node) {
129            let next = edge.target();
130            let next_score = node_score + edge_cost(edge);
131
132            match scores.entry(next) {
133                Occupied(mut entry) => {
134                    // No need to add neighbors that we have already reached through a shorter path
135                    // than now.
136                    if *entry.get() <= next_score {
137                        continue;
138                    }
139                    entry.insert(next_score);
140                }
141                Vacant(entry) => {
142                    entry.insert(next_score);
143                }
144            }
145
146            path_tracker.set_predecessor(next, node);
147            let next_estimate_score = next_score + estimate_cost(next);
148            visit_next.push(MinScored(next_estimate_score, next));
149        }
150    }
151
152    None
153}
154
155struct PathTracker<G>
156where
157    G: GraphBase,
158    G::NodeId: Eq + Hash,
159{
160    came_from: HashMap<G::NodeId, G::NodeId>,
161}
162
163impl<G> PathTracker<G>
164where
165    G: GraphBase,
166    G::NodeId: Eq + Hash,
167{
168    fn new() -> PathTracker<G> {
169        PathTracker {
170            came_from: HashMap::new(),
171        }
172    }
173
174    fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
175        self.came_from.insert(node, previous);
176    }
177
178    fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
179        let mut path = vec![last];
180
181        let mut current = last;
182        while let Some(&previous) = self.came_from.get(&current) {
183            path.push(previous);
184            current = previous;
185        }
186
187        path.reverse();
188
189        path
190    }
191}