petgraph/algo/min_spanning_tree.rs
//! Minimum Spanning Tree algorithms.
2
3use alloc::collections::BinaryHeap;
4
5use hashbrown::{HashMap, HashSet};
6
7use crate::data::Element;
8use crate::prelude::*;
9use crate::scored::MinScored;
10use crate::unionfind::UnionFind;
11use crate::visit::{Data, IntoEdges, IntoNodeReferences, NodeRef};
12use crate::visit::{IntoEdgeReferences, NodeIndexable};
13
14/// \[Generic\] Compute a *minimum spanning tree* of a graph.
15///
16/// The input graph is treated as if undirected.
17///
18/// Using Kruskal's algorithm with runtime **O(|E| log |E|)**. We actually
19/// return a minimum spanning forest, i.e. a minimum spanning tree for each connected
20/// component of the graph.
21///
22/// The resulting graph has all the vertices of the input graph (with identical node indices),
23/// and **|V| - c** edges, where **c** is the number of connected components in `g`.
24///
25/// See also: [`min_spanning_tree_prim`][1] for an implementation using Prim's algorithm.
26///
27/// # Arguments
28/// * `g`: an undirected graph.
29///
30/// # Returns
31/// * [`MinSpanningTree`]: an iterator producing a minimum spanning forest of a graph.
32/// Use `from_elements` to create a graph from the resulting iterator.
33///
34/// # Complexity
35/// * Time complexity: **O(|E| log |E|)**.
36/// * Auxiliary space: **O(|V| + |E|)**.
37///
38/// where **|V|** is the number of nodes and **|E|** is the number of edges.
39///
40/// [1]: fn.min_spanning_tree_prim.html
41///
42/// # Example
43/// ```rust
44/// use petgraph::Graph;
45/// use petgraph::algo::min_spanning_tree;
46/// use petgraph::data::FromElements;
47/// use petgraph::graph::UnGraph;
48///
49/// let mut g = Graph::new_undirected();
50/// let a = g.add_node(());
51/// let b = g.add_node(());
52/// let c = g.add_node(());
53/// let d = g.add_node(());
54/// let e = g.add_node(());
55/// let f = g.add_node(());
56/// g.extend_with_edges(&[
57/// (0, 1, 2.0),
58/// (0, 3, 4.0),
59/// (1, 2, 1.0),
60/// (1, 5, 7.0),
61/// (2, 4, 5.0),
62/// (4, 5, 1.0),
63/// (3, 4, 1.0),
64/// ]);
65///
66/// // The graph looks like this:
67/// // 2 1
68/// // a ----- b ----- c
69/// // | 4 | 7 |
70/// // d f | 5
71/// // | 1 | 1 |
72/// // \------ e ------/
73///
74/// let mst = UnGraph::<_, _>::from_elements(min_spanning_tree(&g));
75/// assert_eq!(g.node_count(), mst.node_count());
76/// assert_eq!(mst.node_count() - 1, mst.edge_count());
77///
78/// // The resulting minimum spanning tree looks like this:
79/// // 2 1
80/// // a ----- b ----- c
81/// // | 4
82/// // d f
83/// // | 1 | 1
84/// // \------ e
85///
86/// let mut edge_weight_vec = mst.edge_weights().cloned().collect::<Vec<_>>();
87/// edge_weight_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
88/// assert_eq!(edge_weight_vec , vec![1.0, 1.0, 1.0, 2.0, 4.0]);
89/// ```
90pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
91where
92 G::NodeWeight: Clone,
93 G::EdgeWeight: Clone + PartialOrd,
94 G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
95{
96 // Initially each vertex is its own disjoint subgraph, track the connectedness
97 // of the pre-MST with a union & find datastructure.
98 let subgraphs = UnionFind::new(g.node_bound());
99
100 let edges = g.edge_references();
101 let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
102 for edge in edges {
103 sort_edges.push(MinScored(
104 edge.weight().clone(),
105 (edge.source(), edge.target()),
106 ));
107 }
108
109 MinSpanningTree {
110 graph: g,
111 node_ids: Some(g.node_references()),
112 subgraphs,
113 sort_edges,
114 node_map: HashMap::new(),
115 node_count: 0,
116 }
117}
118
119/// An iterator producing a minimum spanning forest of a graph.
120/// It will first iterate all Node elements from original graph,
121/// then iterate Edge elements from computed minimum spanning forest.
122#[derive(Debug, Clone)]
123pub struct MinSpanningTree<G>
124where
125 G: Data + IntoNodeReferences,
126{
127 graph: G,
128 node_ids: Option<G::NodeReferences>,
129 subgraphs: UnionFind<usize>,
130 #[allow(clippy::type_complexity)]
131 sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
132 node_map: HashMap<usize, usize>,
133 node_count: usize,
134}
135
136impl<G> Iterator for MinSpanningTree<G>
137where
138 G: IntoNodeReferences + NodeIndexable,
139 G::NodeWeight: Clone,
140 G::EdgeWeight: PartialOrd,
141{
142 type Item = Element<G::NodeWeight, G::EdgeWeight>;
143
144 fn next(&mut self) -> Option<Self::Item> {
145 let g = self.graph;
146 if let Some(ref mut iter) = self.node_ids {
147 if let Some(node) = iter.next() {
148 self.node_map.insert(g.to_index(node.id()), self.node_count);
149 self.node_count += 1;
150 return Some(Element::Node {
151 weight: node.weight().clone(),
152 });
153 }
154 }
155 self.node_ids = None;
156
157 // Kruskal's algorithm.
158 // Algorithm is this:
159 //
160 // 1. Create a pre-MST with all the vertices and no edges.
161 // 2. Repeat:
162 //
163 // a. Remove the shortest edge from the original graph.
164 // b. If the edge connects two disjoint trees in the pre-MST,
165 // add the edge.
166 while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
167 // check if the edge would connect two disjoint parts
168 let (a_index, b_index) = (g.to_index(a), g.to_index(b));
169 if self.subgraphs.union(a_index, b_index) {
170 let (&a_order, &b_order) =
171 match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
172 (Some(a_id), Some(b_id)) => (a_id, b_id),
173 _ => panic!("Edge references unknown node"),
174 };
175 return Some(Element::Edge {
176 source: a_order,
177 target: b_order,
178 weight: score,
179 });
180 }
181 }
182 None
183 }
184}
185
186/// \[Generic\] Compute a *minimum spanning tree* of a graph using Prim's algorithm.
187///
188/// Graph is treated as if undirected. The computed minimum spanning tree can be wrong
189/// if this is not true.
190///
191/// Graph is treated as if connected (has only 1 component). Otherwise, the resulting
192/// graph will only contain edges for an arbitrary minimum spanning tree for a single component.
193///
194/// The resulting graph has all the vertices of the input graph (with identical node indices),
195/// and **|V| - 1** edges if input graph is connected, and |W| edges if disconnected, where |W| < |V| - 1.
196///
197/// See also: [`min_spanning_tree`][1] for an implementation using Kruskal's algorithm and support for minimum spanning forest.
198///
199/// # Arguments
200/// * `g`: an undirected graph.
201///
202/// # Returns
203/// * [`MinSpanningTreePrim`]: an iterator producing a minimum spanning tree of a graph.
204/// Use `from_elements` to create a graph from the resulting iterator.
205///
206/// # Complexity
207/// * Time complexity: **O(|E| log |V|)**.
208/// * Auxiliary space: **O(|V| + |E|)**.
209///
210/// where **|V|** is the number of nodes and **|E|** is the number of edges.
211///
212/// [1]: fn.min_spanning_tree.html
213///
214/// # Example
215/// ```rust
216/// use petgraph::Graph;
217/// use petgraph::algo::min_spanning_tree_prim;
218/// use petgraph::data::FromElements;
219/// use petgraph::graph::UnGraph;
220///
221/// let mut g = Graph::new_undirected();
222/// let a = g.add_node(());
223/// let b = g.add_node(());
224/// let c = g.add_node(());
225/// let d = g.add_node(());
226/// let e = g.add_node(());
227/// let f = g.add_node(());
228/// g.extend_with_edges(&[
229/// (0, 1, 2.0),
230/// (0, 3, 4.0),
231/// (1, 2, 1.0),
232/// (1, 5, 7.0),
233/// (2, 4, 5.0),
234/// (4, 5, 1.0),
235/// (3, 4, 1.0),
236/// ]);
237///
238/// // The graph looks like this:
239/// // 2 1
240/// // a ----- b ----- c
241/// // | 4 | 7 |
242/// // d f | 5
243/// // | 1 | 1 |
244/// // \------ e ------/
245///
246/// let mst = UnGraph::<_, _>::from_elements(min_spanning_tree_prim(&g));
247/// assert_eq!(g.node_count(), mst.node_count());
248/// assert_eq!(mst.node_count() - 1, mst.edge_count());
249///
250/// // The resulting minimum spanning tree looks like this:
251/// // 2 1
252/// // a ----- b ----- c
253/// // | 4
254/// // d f
255/// // | 1 | 1
256/// // \------ e
257///
258/// let mut edge_weight_vec = mst.edge_weights().cloned().collect::<Vec<_>>();
259/// edge_weight_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
260/// assert_eq!(edge_weight_vec , vec![1.0, 1.0, 1.0, 2.0, 4.0]);
261/// ```
262pub fn min_spanning_tree_prim<G>(g: G) -> MinSpanningTreePrim<G>
263where
264 G::EdgeWeight: PartialOrd,
265 G: IntoNodeReferences + IntoEdgeReferences,
266{
267 let sort_edges = BinaryHeap::with_capacity(g.edge_references().size_hint().0);
268 let nodes_taken = HashSet::with_capacity(g.node_references().size_hint().0);
269 let initial_node = g.node_references().next();
270
271 MinSpanningTreePrim {
272 graph: g,
273 node_ids: Some(g.node_references()),
274 node_map: HashMap::new(),
275 node_count: 0,
276 sort_edges,
277 nodes_taken,
278 initial_node,
279 }
280}
281
282/// An iterator producing a minimum spanning tree of a graph.
283/// It will first iterate all Node elements from original graph,
284/// then iterate Edge elements from computed minimum spanning tree.
285#[derive(Debug, Clone)]
286pub struct MinSpanningTreePrim<G>
287where
288 G: IntoNodeReferences,
289{
290 graph: G,
291 node_ids: Option<G::NodeReferences>,
292 node_map: HashMap<usize, usize>,
293 node_count: usize,
294 #[allow(clippy::type_complexity)]
295 sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
296 nodes_taken: HashSet<usize>,
297 initial_node: Option<G::NodeRef>,
298}
299
300impl<G> Iterator for MinSpanningTreePrim<G>
301where
302 G: IntoNodeReferences + IntoEdges + NodeIndexable,
303 G::NodeWeight: Clone,
304 G::EdgeWeight: Clone + PartialOrd,
305{
306 type Item = Element<G::NodeWeight, G::EdgeWeight>;
307
308 fn next(&mut self) -> Option<Self::Item> {
309 // Iterate through Node elements
310 let g = self.graph;
311 if let Some(ref mut iter) = self.node_ids {
312 if let Some(node) = iter.next() {
313 self.node_map.insert(g.to_index(node.id()), self.node_count);
314 self.node_count += 1;
315 return Some(Element::Node {
316 weight: node.weight().clone(),
317 });
318 }
319 }
320 self.node_ids = None;
321
322 // Bootstrap Prim's algorithm to find MST Edge elements.
323 // Mark initial node as taken and add its edges to priority queue.
324 if let Some(initial_node) = self.initial_node {
325 let initial_node_index = g.to_index(initial_node.id());
326 self.nodes_taken.insert(initial_node_index);
327
328 let initial_edges = g.edges(initial_node.id());
329 for edge in initial_edges {
330 self.sort_edges.push(MinScored(
331 edge.weight().clone(),
332 (edge.source(), edge.target()),
333 ));
334 }
335 };
336 self.initial_node = None;
337
338 // Clear edges queue if all nodes were already included in MST.
339 if self.nodes_taken.len() == self.node_count {
340 self.sort_edges.clear();
341 };
342
343 // Prim's algorithm:
344 // Iterate through Edge elements, adding an edge to the MST iff some of it's nodes are not part of MST yet.
345 while let Some(MinScored(score, (source, target))) = self.sort_edges.pop() {
346 let (source_index, target_index) = (g.to_index(source), g.to_index(target));
347
348 if self.nodes_taken.contains(&target_index) {
349 continue;
350 }
351
352 self.nodes_taken.insert(target_index);
353 for edge in g.edges(target) {
354 self.sort_edges.push(MinScored(
355 edge.weight().clone(),
356 (edge.source(), edge.target()),
357 ));
358 }
359
360 let (&source_order, &target_order) = match (
361 self.node_map.get(&source_index),
362 self.node_map.get(&target_index),
363 ) {
364 (Some(source_order), Some(target_order)) => (source_order, target_order),
365 _ => panic!("Edge references unknown node"),
366 };
367
368 return Some(Element::Edge {
369 source: source_order,
370 target: target_order,
371 weight: score,
372 });
373 }
374
375 None
376 }
377}