petgraph/algo/
min_spanning_tree.rs

1//! Minimum Spanning Tree algorithms.
2
3use alloc::collections::BinaryHeap;
4
5use hashbrown::{HashMap, HashSet};
6
7use crate::data::Element;
8use crate::prelude::*;
9use crate::scored::MinScored;
10use crate::unionfind::UnionFind;
11use crate::visit::{Data, IntoEdges, IntoNodeReferences, NodeRef};
12use crate::visit::{IntoEdgeReferences, NodeIndexable};
13
14/// \[Generic\] Compute a *minimum spanning tree* of a graph.
15///
16/// The input graph is treated as if undirected.
17///
18/// Using Kruskal's algorithm with runtime **O(|E| log |E|)**. We actually
19/// return a minimum spanning forest, i.e. a minimum spanning tree for each connected
20/// component of the graph.
21///
22/// The resulting graph has all the vertices of the input graph (with identical node indices),
23/// and **|V| - c** edges, where **c** is the number of connected components in `g`.
24///
25/// See also: [`min_spanning_tree_prim`][1] for an implementation using Prim's algorithm.
26///
27/// # Arguments
28/// * `g`: an undirected graph.
29///
30/// # Returns
31/// * [`MinSpanningTree`]: an iterator producing a minimum spanning forest of a graph.
32///   Use `from_elements` to create a graph from the resulting iterator.
33///
34/// # Complexity
35/// * Time complexity: **O(|E| log |E|)**.
36/// * Auxiliary space: **O(|V| + |E|)**.
37///
38/// where **|V|** is the number of nodes and **|E|** is the number of edges.
39///
40/// [1]: fn.min_spanning_tree_prim.html
41///
42/// # Example
43/// ```rust
44/// use petgraph::Graph;
45/// use petgraph::algo::min_spanning_tree;
46/// use petgraph::data::FromElements;
47/// use petgraph::graph::UnGraph;
48///
49/// let mut g = Graph::new_undirected();
50/// let a = g.add_node(());
51/// let b = g.add_node(());
52/// let c = g.add_node(());
53/// let d = g.add_node(());
54/// let e = g.add_node(());
55/// let f = g.add_node(());
56/// g.extend_with_edges(&[
57///     (0, 1, 2.0),
58///     (0, 3, 4.0),
59///     (1, 2, 1.0),
60///     (1, 5, 7.0),
61///     (2, 4, 5.0),
62///     (4, 5, 1.0),
63///     (3, 4, 1.0),
64/// ]);
65///
66/// // The graph looks like this:
67/// //     2       1
68/// // a ----- b ----- c
69/// // | 4     | 7     |
70/// // d       f       | 5
71/// // | 1     | 1     |
72/// // \------ e ------/
73///
74/// let mst = UnGraph::<_, _>::from_elements(min_spanning_tree(&g));
75/// assert_eq!(g.node_count(), mst.node_count());
76/// assert_eq!(mst.node_count() - 1, mst.edge_count());
77///
78/// // The resulting minimum spanning tree looks like this:
79/// //     2       1
80/// // a ----- b ----- c
81/// // | 4             
82/// // d       f       
83/// // | 1     | 1       
84/// // \------ e
85///
86/// let mut edge_weight_vec = mst.edge_weights().cloned().collect::<Vec<_>>();
87/// edge_weight_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
88/// assert_eq!(edge_weight_vec , vec![1.0, 1.0, 1.0, 2.0, 4.0]);
89/// ```
90pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
91where
92    G::NodeWeight: Clone,
93    G::EdgeWeight: Clone + PartialOrd,
94    G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
95{
96    // Initially each vertex is its own disjoint subgraph, track the connectedness
97    // of the pre-MST with a union & find datastructure.
98    let subgraphs = UnionFind::new(g.node_bound());
99
100    let edges = g.edge_references();
101    let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
102    for edge in edges {
103        sort_edges.push(MinScored(
104            edge.weight().clone(),
105            (edge.source(), edge.target()),
106        ));
107    }
108
109    MinSpanningTree {
110        graph: g,
111        node_ids: Some(g.node_references()),
112        subgraphs,
113        sort_edges,
114        node_map: HashMap::new(),
115        node_count: 0,
116    }
117}
118
119/// An iterator producing a minimum spanning forest of a graph.
120/// It will first iterate all Node elements from original graph,
121/// then iterate Edge elements from computed minimum spanning forest.
122#[derive(Debug, Clone)]
123pub struct MinSpanningTree<G>
124where
125    G: Data + IntoNodeReferences,
126{
127    graph: G,
128    node_ids: Option<G::NodeReferences>,
129    subgraphs: UnionFind<usize>,
130    #[allow(clippy::type_complexity)]
131    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
132    node_map: HashMap<usize, usize>,
133    node_count: usize,
134}
135
136impl<G> Iterator for MinSpanningTree<G>
137where
138    G: IntoNodeReferences + NodeIndexable,
139    G::NodeWeight: Clone,
140    G::EdgeWeight: PartialOrd,
141{
142    type Item = Element<G::NodeWeight, G::EdgeWeight>;
143
144    fn next(&mut self) -> Option<Self::Item> {
145        let g = self.graph;
146        if let Some(ref mut iter) = self.node_ids {
147            if let Some(node) = iter.next() {
148                self.node_map.insert(g.to_index(node.id()), self.node_count);
149                self.node_count += 1;
150                return Some(Element::Node {
151                    weight: node.weight().clone(),
152                });
153            }
154        }
155        self.node_ids = None;
156
157        // Kruskal's algorithm.
158        // Algorithm is this:
159        //
160        // 1. Create a pre-MST with all the vertices and no edges.
161        // 2. Repeat:
162        //
163        //  a. Remove the shortest edge from the original graph.
164        //  b. If the edge connects two disjoint trees in the pre-MST,
165        //     add the edge.
166        while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
167            // check if the edge would connect two disjoint parts
168            let (a_index, b_index) = (g.to_index(a), g.to_index(b));
169            if self.subgraphs.union(a_index, b_index) {
170                let (&a_order, &b_order) =
171                    match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
172                        (Some(a_id), Some(b_id)) => (a_id, b_id),
173                        _ => panic!("Edge references unknown node"),
174                    };
175                return Some(Element::Edge {
176                    source: a_order,
177                    target: b_order,
178                    weight: score,
179                });
180            }
181        }
182        None
183    }
184}
185
186/// \[Generic\] Compute a *minimum spanning tree* of a graph using Prim's algorithm.
187///
188/// Graph is treated as if undirected. The computed minimum spanning tree can be wrong
189/// if this is not true.
190///
191/// Graph is treated as if connected (has only 1 component). Otherwise, the resulting
192/// graph will only contain edges for an arbitrary minimum spanning tree for a single component.
193///
194/// The resulting graph has all the vertices of the input graph (with identical node indices),
195/// and **|V| - 1** edges if input graph is connected, and |W| edges if disconnected, where |W| < |V| - 1.
196///
197/// See also: [`min_spanning_tree`][1] for an implementation using Kruskal's algorithm and support for minimum spanning forest.
198///
199/// # Arguments
200/// * `g`: an undirected graph.
201///
202/// # Returns
203/// * [`MinSpanningTreePrim`]: an iterator producing a minimum spanning tree of a graph.
204///   Use `from_elements` to create a graph from the resulting iterator.
205///
206/// # Complexity
207/// * Time complexity: **O(|E| log |V|)**.
208/// * Auxiliary space: **O(|V| + |E|)**.
209///
210/// where **|V|** is the number of nodes and **|E|** is the number of edges.
211///
212/// [1]: fn.min_spanning_tree.html
213///
214/// # Example
215/// ```rust
216/// use petgraph::Graph;
217/// use petgraph::algo::min_spanning_tree_prim;
218/// use petgraph::data::FromElements;
219/// use petgraph::graph::UnGraph;
220///
221/// let mut g = Graph::new_undirected();
222/// let a = g.add_node(());
223/// let b = g.add_node(());
224/// let c = g.add_node(());
225/// let d = g.add_node(());
226/// let e = g.add_node(());
227/// let f = g.add_node(());
228/// g.extend_with_edges(&[
229///     (0, 1, 2.0),
230///     (0, 3, 4.0),
231///     (1, 2, 1.0),
232///     (1, 5, 7.0),
233///     (2, 4, 5.0),
234///     (4, 5, 1.0),
235///     (3, 4, 1.0),
236/// ]);
237///
238/// // The graph looks like this:
239/// //     2       1
240/// // a ----- b ----- c
241/// // | 4     | 7     |
242/// // d       f       | 5
243/// // | 1     | 1     |
244/// // \------ e ------/
245///
246/// let mst = UnGraph::<_, _>::from_elements(min_spanning_tree_prim(&g));
247/// assert_eq!(g.node_count(), mst.node_count());
248/// assert_eq!(mst.node_count() - 1, mst.edge_count());
249///
250/// // The resulting minimum spanning tree looks like this:
251/// //     2       1
252/// // a ----- b ----- c
253/// // | 4
254/// // d       f
255/// // | 1     | 1
256/// // \------ e
257///
258/// let mut edge_weight_vec = mst.edge_weights().cloned().collect::<Vec<_>>();
259/// edge_weight_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
260/// assert_eq!(edge_weight_vec , vec![1.0, 1.0, 1.0, 2.0, 4.0]);
261/// ```
262pub fn min_spanning_tree_prim<G>(g: G) -> MinSpanningTreePrim<G>
263where
264    G::EdgeWeight: PartialOrd,
265    G: IntoNodeReferences + IntoEdgeReferences,
266{
267    let sort_edges = BinaryHeap::with_capacity(g.edge_references().size_hint().0);
268    let nodes_taken = HashSet::with_capacity(g.node_references().size_hint().0);
269    let initial_node = g.node_references().next();
270
271    MinSpanningTreePrim {
272        graph: g,
273        node_ids: Some(g.node_references()),
274        node_map: HashMap::new(),
275        node_count: 0,
276        sort_edges,
277        nodes_taken,
278        initial_node,
279    }
280}
281
282/// An iterator producing a minimum spanning tree of a graph.
283/// It will first iterate all Node elements from original graph,
284/// then iterate Edge elements from computed minimum spanning tree.
285#[derive(Debug, Clone)]
286pub struct MinSpanningTreePrim<G>
287where
288    G: IntoNodeReferences,
289{
290    graph: G,
291    node_ids: Option<G::NodeReferences>,
292    node_map: HashMap<usize, usize>,
293    node_count: usize,
294    #[allow(clippy::type_complexity)]
295    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
296    nodes_taken: HashSet<usize>,
297    initial_node: Option<G::NodeRef>,
298}
299
300impl<G> Iterator for MinSpanningTreePrim<G>
301where
302    G: IntoNodeReferences + IntoEdges + NodeIndexable,
303    G::NodeWeight: Clone,
304    G::EdgeWeight: Clone + PartialOrd,
305{
306    type Item = Element<G::NodeWeight, G::EdgeWeight>;
307
308    fn next(&mut self) -> Option<Self::Item> {
309        // Iterate through Node elements
310        let g = self.graph;
311        if let Some(ref mut iter) = self.node_ids {
312            if let Some(node) = iter.next() {
313                self.node_map.insert(g.to_index(node.id()), self.node_count);
314                self.node_count += 1;
315                return Some(Element::Node {
316                    weight: node.weight().clone(),
317                });
318            }
319        }
320        self.node_ids = None;
321
322        // Bootstrap Prim's algorithm to find MST Edge elements.
323        // Mark initial node as taken and add its edges to priority queue.
324        if let Some(initial_node) = self.initial_node {
325            let initial_node_index = g.to_index(initial_node.id());
326            self.nodes_taken.insert(initial_node_index);
327
328            let initial_edges = g.edges(initial_node.id());
329            for edge in initial_edges {
330                self.sort_edges.push(MinScored(
331                    edge.weight().clone(),
332                    (edge.source(), edge.target()),
333                ));
334            }
335        };
336        self.initial_node = None;
337
338        // Clear edges queue if all nodes were already included in MST.
339        if self.nodes_taken.len() == self.node_count {
340            self.sort_edges.clear();
341        };
342
343        // Prim's algorithm:
344        // Iterate through Edge elements, adding an edge to the MST iff some of it's nodes are not part of MST yet.
345        while let Some(MinScored(score, (source, target))) = self.sort_edges.pop() {
346            let (source_index, target_index) = (g.to_index(source), g.to_index(target));
347
348            if self.nodes_taken.contains(&target_index) {
349                continue;
350            }
351
352            self.nodes_taken.insert(target_index);
353            for edge in g.edges(target) {
354                self.sort_edges.push(MinScored(
355                    edge.weight().clone(),
356                    (edge.source(), edge.target()),
357                ));
358            }
359
360            let (&source_order, &target_order) = match (
361                self.node_map.get(&source_index),
362                self.node_map.get(&target_index),
363            ) {
364                (Some(source_order), Some(target_order)) => (source_order, target_order),
365                _ => panic!("Edge references unknown node"),
366            };
367
368            return Some(Element::Edge {
369                source: source_order,
370                target: target_order,
371                weight: score,
372            });
373        }
374
375        None
376    }
377}