petgraph/
csr.rs

1//! Compressed Sparse Row (CSR) is a sparse adjacency matrix graph.
2
3use std::cmp::{max, Ordering};
4use std::iter::{Enumerate, Zip};
5use std::marker::PhantomData;
6use std::ops::{Index, IndexMut, Range};
7use std::slice::Windows;
8
9use crate::visit::{
10    Data, EdgeCount, EdgeRef, GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences,
11    IntoEdges, IntoNeighbors, IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable,
12    NodeCount, NodeIndexable, Visitable,
13};
14
15use crate::util::zip;
16
17#[doc(no_inline)]
18pub use crate::graph::{DefaultIx, IndexType};
19
20use crate::{Directed, EdgeType, IntoWeightedEdge};
21
22/// Csr node index type, a plain integer.
23pub type NodeIndex<Ix = DefaultIx> = Ix;
24/// Csr edge index type, a plain integer.
25pub type EdgeIndex = usize;
26
27const BINARY_SEARCH_CUTOFF: usize = 32;
28
29/// Compressed Sparse Row ([`CSR`]) is a sparse adjacency matrix graph.
30///
31/// `CSR` is parameterized over:
32///
33/// - Associated data `N` for nodes and `E` for edges, called *weights*.
34///   The associated data can be of arbitrary type.
35/// - Edge type `Ty` that determines whether the graph edges are directed or undirected.
36/// - Index type `Ix`, which determines the maximum size of the graph.
37///
38///
39/// Using **O(|E| + |V|)** space.
40///
41/// Self loops are allowed, no parallel edges.
42///
43/// Fast iteration of the outgoing edges of a vertex.
44///
45/// [`CSR`]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
46#[derive(Debug)]
47pub struct Csr<N = (), E = (), Ty = Directed, Ix = DefaultIx> {
48    /// Column of next edge
49    column: Vec<NodeIndex<Ix>>,
50    /// weight of each edge; lock step with column
51    edges: Vec<E>,
52    /// Index of start of row Always node_count + 1 long.
53    /// Last element is always equal to column.len()
54    row: Vec<usize>,
55    node_weights: Vec<N>,
56    edge_count: usize,
57    ty: PhantomData<Ty>,
58}
59
60impl<N, E, Ty, Ix> Default for Csr<N, E, Ty, Ix>
61where
62    Ty: EdgeType,
63    Ix: IndexType,
64{
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl<N: Clone, E: Clone, Ty, Ix: Clone> Clone for Csr<N, E, Ty, Ix> {
71    fn clone(&self) -> Self {
72        Csr {
73            column: self.column.clone(),
74            edges: self.edges.clone(),
75            row: self.row.clone(),
76            node_weights: self.node_weights.clone(),
77            edge_count: self.edge_count,
78            ty: self.ty,
79        }
80    }
81}
82
83impl<N, E, Ty, Ix> Csr<N, E, Ty, Ix>
84where
85    Ty: EdgeType,
86    Ix: IndexType,
87{
88    /// Create an empty `Csr`.
89    pub fn new() -> Self {
90        Csr {
91            column: vec![],
92            edges: vec![],
93            row: vec![0; 1],
94            node_weights: vec![],
95            edge_count: 0,
96            ty: PhantomData,
97        }
98    }
99
100    /// Create a new `Csr` with `n` nodes. `N` must implement [`Default`] for the weight of each node.
101    ///
102    /// [`Default`]: https://doc.rust-lang.org/nightly/core/default/trait.Default.html
103    ///
104    /// # Example
105    /// ```rust
106    /// use petgraph::csr::Csr;
107    /// use petgraph::prelude::*;
108    ///
109    /// let graph = Csr::<u8,()>::with_nodes(5);
110    /// assert_eq!(graph.node_count(),5);
111    /// assert_eq!(graph.edge_count(),0);
112    ///
113    /// assert_eq!(graph[0],0);
114    /// assert_eq!(graph[4],0);
115    /// ```
116    pub fn with_nodes(n: usize) -> Self
117    where
118        N: Default,
119    {
120        Csr {
121            column: Vec::new(),
122            edges: Vec::new(),
123            row: vec![0; n + 1],
124            node_weights: (0..n).map(|_| N::default()).collect(),
125            edge_count: 0,
126            ty: PhantomData,
127        }
128    }
129}
130
/// Csr creation error: edges were not in sorted order.
///
/// Returned by `Csr::from_sorted_edges` when the input is not strictly
/// increasing in *(source, target)* order. This covers both an out-of-order
/// edge and a duplicate edge. `first_error` holds the *(source, target)*
/// indices of the first offending edge.
#[derive(Clone, Debug)]
pub struct EdgesNotSorted {
    first_error: (usize, usize),
}

impl std::fmt::Display for EdgesNotSorted {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (source, target) = self.first_error;
        write!(
            f,
            "edges are not sorted and unique: first offending edge is ({}, {})",
            source, target
        )
    }
}

// Lets the error be boxed as `Box<dyn std::error::Error>` and propagated with `?`.
impl std::error::Error for EdgesNotSorted {}
137
138impl<N, E, Ix> Csr<N, E, Directed, Ix>
139where
140    Ix: IndexType,
141{
142    /// Create a new `Csr` from a sorted sequence of edges
143    ///
144    /// Edges **must** be sorted and unique, where the sort order is the default
145    /// order for the pair *(u, v)* in Rust (*u* has priority).
146    ///
147    /// Computes in **O(|E| + |V|)** time.
148    /// # Example
149    /// ```rust
150    /// use petgraph::csr::Csr;
151    /// use petgraph::prelude::*;
152    ///
153    /// let graph = Csr::<(),()>::from_sorted_edges(&[
154    ///                     (0, 1), (0, 2),
155    ///                     (1, 0), (1, 2), (1, 3),
156    ///                     (2, 0),
157    ///                     (3, 1),
158    /// ]);
159    /// ```
160    pub fn from_sorted_edges<Edge>(edges: &[Edge]) -> Result<Self, EdgesNotSorted>
161    where
162        Edge: Clone + IntoWeightedEdge<E, NodeId = NodeIndex<Ix>>,
163        N: Default,
164    {
165        let max_node_id = match edges
166            .iter()
167            .map(|edge| {
168                let (x, y, _) = edge.clone().into_weighted_edge();
169                max(x.index(), y.index())
170            })
171            .max()
172        {
173            None => return Ok(Self::with_nodes(0)),
174            Some(x) => x,
175        };
176        let mut self_ = Self::with_nodes(max_node_id + 1);
177        let mut iter = edges.iter().cloned().peekable();
178        {
179            let mut rows = self_.row.iter_mut();
180
181            let mut rstart = 0;
182            let mut last_target;
183            'outer: for (node, r) in (&mut rows).enumerate() {
184                *r = rstart;
185                last_target = None;
186                'inner: loop {
187                    if let Some(edge) = iter.peek() {
188                        let (n, m, weight) = edge.clone().into_weighted_edge();
189                        // check that the edges are in increasing sequence
190                        if node > n.index() {
191                            return Err(EdgesNotSorted {
192                                first_error: (n.index(), m.index()),
193                            });
194                        }
195                        /*
196                        debug_assert!(node <= n.index(),
197                                      concat!("edges are not sorted, ",
198                                              "failed assertion source {:?} <= {:?} ",
199                                              "for edge {:?}"),
200                                      node, n, (n, m));
201                                      */
202                        if n.index() != node {
203                            break 'inner;
204                        }
205                        // check that the edges are in increasing sequence
206                        /*
207                        debug_assert!(last_target.map_or(true, |x| m > x),
208                                      "edges are not sorted, failed assertion {:?} < {:?}",
209                                      last_target, m);
210                                      */
211                        if !last_target.map_or(true, |x| m > x) {
212                            return Err(EdgesNotSorted {
213                                first_error: (n.index(), m.index()),
214                            });
215                        }
216                        last_target = Some(m);
217                        self_.column.push(m);
218                        self_.edges.push(weight);
219                        rstart += 1;
220                    } else {
221                        break 'outer;
222                    }
223                    iter.next();
224                }
225            }
226            for r in rows {
227                *r = rstart;
228            }
229        }
230
231        Ok(self_)
232    }
233}
234
235impl<N, E, Ty, Ix> Csr<N, E, Ty, Ix>
236where
237    Ty: EdgeType,
238    Ix: IndexType,
239{
240    pub fn node_count(&self) -> usize {
241        self.row.len() - 1
242    }
243
244    pub fn edge_count(&self) -> usize {
245        if self.is_directed() {
246            self.column.len()
247        } else {
248            self.edge_count
249        }
250    }
251
252    pub fn is_directed(&self) -> bool {
253        Ty::is_directed()
254    }
255
256    /// Remove all edges
257    pub fn clear_edges(&mut self) {
258        self.column.clear();
259        self.edges.clear();
260        for r in &mut self.row {
261            *r = 0;
262        }
263        if !self.is_directed() {
264            self.edge_count = 0;
265        }
266    }
267
268    /// Adds a new node with the given weight, returning the corresponding node index.
269    pub fn add_node(&mut self, weight: N) -> NodeIndex<Ix> {
270        let i = self.row.len() - 1;
271        self.row.insert(i, self.column.len());
272        self.node_weights.insert(i, weight);
273        Ix::new(i)
274    }
275
276    /// Return `true` if the edge was added
277    ///
278    /// If you add all edges in row-major order, the time complexity
279    /// is **O(|V|·|E|)** for the whole operation.
280    ///
281    /// **Panics** if `a` or `b` are out of bounds.
282    pub fn add_edge(&mut self, a: NodeIndex<Ix>, b: NodeIndex<Ix>, weight: E) -> bool
283    where
284        E: Clone,
285    {
286        let ret = self.add_edge_(a, b, weight.clone());
287        if ret && !self.is_directed() {
288            self.edge_count += 1;
289        }
290        if ret && !self.is_directed() && a != b {
291            let _ret2 = self.add_edge_(b, a, weight);
292            debug_assert_eq!(ret, _ret2);
293        }
294        ret
295    }
296
297    // Return false if the edge already exists
298    fn add_edge_(&mut self, a: NodeIndex<Ix>, b: NodeIndex<Ix>, weight: E) -> bool {
299        assert!(a.index() < self.node_count() && b.index() < self.node_count());
300        // a x b is at (a, b) in the matrix
301
302        // find current range of edges from a
303        let pos = match self.find_edge_pos(a, b) {
304            Ok(_) => return false, /* already exists */
305            Err(i) => i,
306        };
307        self.column.insert(pos, b);
308        self.edges.insert(pos, weight);
309        // update row vector
310        for r in &mut self.row[a.index() + 1..] {
311            *r += 1;
312        }
313        true
314    }
315
316    fn find_edge_pos(&self, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> Result<usize, usize> {
317        let (index, neighbors) = self.neighbors_of(a);
318        if neighbors.len() < BINARY_SEARCH_CUTOFF {
319            for (i, elt) in neighbors.iter().enumerate() {
320                match elt.cmp(&b) {
321                    Ordering::Equal => return Ok(i + index),
322                    Ordering::Greater => return Err(i + index),
323                    Ordering::Less => {}
324                }
325            }
326            Err(neighbors.len() + index)
327        } else {
328            match neighbors.binary_search(&b) {
329                Ok(i) => Ok(i + index),
330                Err(i) => Err(i + index),
331            }
332        }
333    }
334
335    /// Computes in **O(log |V|)** time.
336    ///
337    /// **Panics** if the node `a` does not exist.
338    pub fn contains_edge(&self, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> bool {
339        self.find_edge_pos(a, b).is_ok()
340    }
341
342    fn neighbors_range(&self, a: NodeIndex<Ix>) -> Range<usize> {
343        let index = self.row[a.index()];
344        let end = self
345            .row
346            .get(a.index() + 1)
347            .cloned()
348            .unwrap_or(self.column.len());
349        index..end
350    }
351
352    fn neighbors_of(&self, a: NodeIndex<Ix>) -> (usize, &[Ix]) {
353        let r = self.neighbors_range(a);
354        (r.start, &self.column[r])
355    }
356
357    /// Computes in **O(1)** time.
358    ///
359    /// **Panics** if the node `a` does not exist.
360    pub fn out_degree(&self, a: NodeIndex<Ix>) -> usize {
361        let r = self.neighbors_range(a);
362        r.end - r.start
363    }
364
365    /// Computes in **O(1)** time.
366    ///
367    /// **Panics** if the node `a` does not exist.
368    pub fn neighbors_slice(&self, a: NodeIndex<Ix>) -> &[NodeIndex<Ix>] {
369        self.neighbors_of(a).1
370    }
371
372    /// Computes in **O(1)** time.
373    ///
374    /// **Panics** if the node `a` does not exist.
375    pub fn edges_slice(&self, a: NodeIndex<Ix>) -> &[E] {
376        &self.edges[self.neighbors_range(a)]
377    }
378
379    /// Return an iterator of all edges of `a`.
380    ///
381    /// - `Directed`: Outgoing edges from `a`.
382    /// - `Undirected`: All edges connected to `a`.
383    ///
384    /// **Panics** if the node `a` does not exist.<br>
385    /// Iterator element type is `EdgeReference<E, Ty, Ix>`.
386    pub fn edges(&self, a: NodeIndex<Ix>) -> Edges<E, Ty, Ix> {
387        let r = self.neighbors_range(a);
388        Edges {
389            index: r.start,
390            source: a,
391            iter: zip(&self.column[r.clone()], &self.edges[r]),
392            ty: self.ty,
393        }
394    }
395}
396
397#[derive(Clone, Debug)]
398pub struct Edges<'a, E: 'a, Ty = Directed, Ix: 'a = DefaultIx> {
399    index: usize,
400    source: NodeIndex<Ix>,
401    iter: Zip<SliceIter<'a, NodeIndex<Ix>>, SliceIter<'a, E>>,
402    ty: PhantomData<Ty>,
403}
404
405#[derive(Debug)]
406pub struct EdgeReference<'a, E: 'a, Ty, Ix: 'a = DefaultIx> {
407    index: EdgeIndex,
408    source: NodeIndex<Ix>,
409    target: NodeIndex<Ix>,
410    weight: &'a E,
411    ty: PhantomData<Ty>,
412}
413
414impl<E, Ty, Ix: Copy> Clone for EdgeReference<'_, E, Ty, Ix> {
415    fn clone(&self) -> Self {
416        *self
417    }
418}
419
420impl<E, Ty, Ix: Copy> Copy for EdgeReference<'_, E, Ty, Ix> {}
421
422impl<'a, Ty, E, Ix> EdgeReference<'a, E, Ty, Ix>
423where
424    Ty: EdgeType,
425{
426    /// Access the edge’s weight.
427    ///
428    /// **NOTE** that this method offers a longer lifetime
429    /// than the trait (unfortunately they don't match yet).
430    pub fn weight(&self) -> &'a E {
431        self.weight
432    }
433}
434
435impl<E, Ty, Ix> EdgeRef for EdgeReference<'_, E, Ty, Ix>
436where
437    Ty: EdgeType,
438    Ix: IndexType,
439{
440    type NodeId = NodeIndex<Ix>;
441    type EdgeId = EdgeIndex;
442    type Weight = E;
443
444    fn source(&self) -> Self::NodeId {
445        self.source
446    }
447    fn target(&self) -> Self::NodeId {
448        self.target
449    }
450    fn weight(&self) -> &E {
451        self.weight
452    }
453    fn id(&self) -> Self::EdgeId {
454        self.index
455    }
456}
457
458impl<'a, E, Ty, Ix> Iterator for Edges<'a, E, Ty, Ix>
459where
460    Ty: EdgeType,
461    Ix: IndexType,
462{
463    type Item = EdgeReference<'a, E, Ty, Ix>;
464    fn next(&mut self) -> Option<Self::Item> {
465        self.iter.next().map(move |(&j, w)| {
466            let index = self.index;
467            self.index += 1;
468            EdgeReference {
469                index,
470                source: self.source,
471                target: j,
472                weight: w,
473                ty: PhantomData,
474            }
475        })
476    }
477    fn size_hint(&self) -> (usize, Option<usize>) {
478        self.iter.size_hint()
479    }
480}
481
482impl<N, E, Ty, Ix> Data for Csr<N, E, Ty, Ix>
483where
484    Ty: EdgeType,
485    Ix: IndexType,
486{
487    type NodeWeight = N;
488    type EdgeWeight = E;
489}
490
491impl<'a, N, E, Ty, Ix> IntoEdgeReferences for &'a Csr<N, E, Ty, Ix>
492where
493    Ty: EdgeType,
494    Ix: IndexType,
495{
496    type EdgeRef = EdgeReference<'a, E, Ty, Ix>;
497    type EdgeReferences = EdgeReferences<'a, E, Ty, Ix>;
498    fn edge_references(self) -> Self::EdgeReferences {
499        EdgeReferences {
500            index: 0,
501            source_index: Ix::new(0),
502            edge_ranges: self.row.windows(2).enumerate(),
503            column: &self.column,
504            edges: &self.edges,
505            iter: zip(&[], &[]),
506            ty: self.ty,
507        }
508    }
509}
510
511#[derive(Debug, Clone)]
512pub struct EdgeReferences<'a, E: 'a, Ty, Ix: 'a> {
513    source_index: NodeIndex<Ix>,
514    index: usize,
515    edge_ranges: Enumerate<Windows<'a, usize>>,
516    column: &'a [NodeIndex<Ix>],
517    edges: &'a [E],
518    iter: Zip<SliceIter<'a, NodeIndex<Ix>>, SliceIter<'a, E>>,
519    ty: PhantomData<Ty>,
520}
521
522impl<'a, E, Ty, Ix> Iterator for EdgeReferences<'a, E, Ty, Ix>
523where
524    Ty: EdgeType,
525    Ix: IndexType,
526{
527    type Item = EdgeReference<'a, E, Ty, Ix>;
528    fn next(&mut self) -> Option<Self::Item> {
529        loop {
530            if let Some((&j, w)) = self.iter.next() {
531                let index = self.index;
532                self.index += 1;
533                return Some(EdgeReference {
534                    index,
535                    source: self.source_index,
536                    target: j,
537                    weight: w,
538                    ty: PhantomData,
539                });
540            }
541            if let Some((i, w)) = self.edge_ranges.next() {
542                let a = w[0];
543                let b = w[1];
544                self.iter = zip(&self.column[a..b], &self.edges[a..b]);
545                self.source_index = Ix::new(i);
546            } else {
547                return None;
548            }
549        }
550    }
551}
552
553impl<'a, N, E, Ty, Ix> IntoEdges for &'a Csr<N, E, Ty, Ix>
554where
555    Ty: EdgeType,
556    Ix: IndexType,
557{
558    type Edges = Edges<'a, E, Ty, Ix>;
559    fn edges(self, a: Self::NodeId) -> Self::Edges {
560        self.edges(a)
561    }
562}
563
564impl<N, E, Ty, Ix> GraphBase for Csr<N, E, Ty, Ix>
565where
566    Ty: EdgeType,
567    Ix: IndexType,
568{
569    type NodeId = NodeIndex<Ix>;
570    type EdgeId = EdgeIndex; // index into edges vector
571}
572
573use fixedbitset::FixedBitSet;
574
575impl<N, E, Ty, Ix> Visitable for Csr<N, E, Ty, Ix>
576where
577    Ty: EdgeType,
578    Ix: IndexType,
579{
580    type Map = FixedBitSet;
581    fn visit_map(&self) -> FixedBitSet {
582        FixedBitSet::with_capacity(self.node_count())
583    }
584    fn reset_map(&self, map: &mut Self::Map) {
585        map.clear();
586        map.grow(self.node_count());
587    }
588}
589
590use std::slice::Iter as SliceIter;
591
592#[derive(Clone, Debug)]
593pub struct Neighbors<'a, Ix: 'a = DefaultIx> {
594    iter: SliceIter<'a, NodeIndex<Ix>>,
595}
596
597impl<Ix> Iterator for Neighbors<'_, Ix>
598where
599    Ix: IndexType,
600{
601    type Item = NodeIndex<Ix>;
602
603    fn next(&mut self) -> Option<Self::Item> {
604        self.iter.next().cloned()
605    }
606
607    fn size_hint(&self) -> (usize, Option<usize>) {
608        self.iter.size_hint()
609    }
610}
611
612impl<'a, N, E, Ty, Ix> IntoNeighbors for &'a Csr<N, E, Ty, Ix>
613where
614    Ty: EdgeType,
615    Ix: IndexType,
616{
617    type Neighbors = Neighbors<'a, Ix>;
618
619    /// Return an iterator of all neighbors of `a`.
620    ///
621    /// - `Directed`: Targets of outgoing edges from `a`.
622    /// - `Undirected`: Opposing endpoints of all edges connected to `a`.
623    ///
624    /// **Panics** if the node `a` does not exist.<br>
625    /// Iterator element type is `NodeIndex<Ix>`.
626    fn neighbors(self, a: Self::NodeId) -> Self::Neighbors {
627        Neighbors {
628            iter: self.neighbors_slice(a).iter(),
629        }
630    }
631}
632
633impl<N, E, Ty, Ix> NodeIndexable for Csr<N, E, Ty, Ix>
634where
635    Ty: EdgeType,
636    Ix: IndexType,
637{
638    fn node_bound(&self) -> usize {
639        self.node_count()
640    }
641    fn to_index(&self, a: Self::NodeId) -> usize {
642        a.index()
643    }
644    fn from_index(&self, ix: usize) -> Self::NodeId {
645        Ix::new(ix)
646    }
647}
648
649impl<N, E, Ty, Ix> NodeCompactIndexable for Csr<N, E, Ty, Ix>
650where
651    Ty: EdgeType,
652    Ix: IndexType,
653{
654}
655
656impl<N, E, Ty, Ix> Index<NodeIndex<Ix>> for Csr<N, E, Ty, Ix>
657where
658    Ty: EdgeType,
659    Ix: IndexType,
660{
661    type Output = N;
662
663    fn index(&self, ix: NodeIndex<Ix>) -> &N {
664        &self.node_weights[ix.index()]
665    }
666}
667
668impl<N, E, Ty, Ix> IndexMut<NodeIndex<Ix>> for Csr<N, E, Ty, Ix>
669where
670    Ty: EdgeType,
671    Ix: IndexType,
672{
673    fn index_mut(&mut self, ix: NodeIndex<Ix>) -> &mut N {
674        &mut self.node_weights[ix.index()]
675    }
676}
677
678#[derive(Debug, Clone)]
679pub struct NodeIdentifiers<Ix = DefaultIx> {
680    r: Range<usize>,
681    ty: PhantomData<Ix>,
682}
683
684impl<Ix> Iterator for NodeIdentifiers<Ix>
685where
686    Ix: IndexType,
687{
688    type Item = NodeIndex<Ix>;
689
690    fn next(&mut self) -> Option<Self::Item> {
691        self.r.next().map(Ix::new)
692    }
693
694    fn size_hint(&self) -> (usize, Option<usize>) {
695        self.r.size_hint()
696    }
697}
698
699impl<N, E, Ty, Ix> IntoNodeIdentifiers for &Csr<N, E, Ty, Ix>
700where
701    Ty: EdgeType,
702    Ix: IndexType,
703{
704    type NodeIdentifiers = NodeIdentifiers<Ix>;
705    fn node_identifiers(self) -> Self::NodeIdentifiers {
706        NodeIdentifiers {
707            r: 0..self.node_count(),
708            ty: PhantomData,
709        }
710    }
711}
712
713impl<N, E, Ty, Ix> NodeCount for Csr<N, E, Ty, Ix>
714where
715    Ty: EdgeType,
716    Ix: IndexType,
717{
718    fn node_count(&self) -> usize {
719        (*self).node_count()
720    }
721}
722
723impl<N, E, Ty, Ix> EdgeCount for Csr<N, E, Ty, Ix>
724where
725    Ty: EdgeType,
726    Ix: IndexType,
727{
728    #[inline]
729    fn edge_count(&self) -> usize {
730        self.edge_count()
731    }
732}
733
734impl<N, E, Ty, Ix> GraphProp for Csr<N, E, Ty, Ix>
735where
736    Ty: EdgeType,
737    Ix: IndexType,
738{
739    type EdgeType = Ty;
740}
741
742impl<'a, N, E, Ty, Ix> IntoNodeReferences for &'a Csr<N, E, Ty, Ix>
743where
744    Ty: EdgeType,
745    Ix: IndexType,
746{
747    type NodeRef = (NodeIndex<Ix>, &'a N);
748    type NodeReferences = NodeReferences<'a, N, Ix>;
749    fn node_references(self) -> Self::NodeReferences {
750        NodeReferences {
751            iter: self.node_weights.iter().enumerate(),
752            ty: PhantomData,
753        }
754    }
755}
756
757/// Iterator over all nodes of a graph.
758#[derive(Debug, Clone)]
759pub struct NodeReferences<'a, N: 'a, Ix: IndexType = DefaultIx> {
760    iter: Enumerate<SliceIter<'a, N>>,
761    ty: PhantomData<Ix>,
762}
763
764impl<'a, N, Ix> Iterator for NodeReferences<'a, N, Ix>
765where
766    Ix: IndexType,
767{
768    type Item = (NodeIndex<Ix>, &'a N);
769
770    fn next(&mut self) -> Option<Self::Item> {
771        self.iter.next().map(|(i, weight)| (Ix::new(i), weight))
772    }
773
774    fn size_hint(&self) -> (usize, Option<usize>) {
775        self.iter.size_hint()
776    }
777}
778
779impl<N, Ix> DoubleEndedIterator for NodeReferences<'_, N, Ix>
780where
781    Ix: IndexType,
782{
783    fn next_back(&mut self) -> Option<Self::Item> {
784        self.iter
785            .next_back()
786            .map(|(i, weight)| (Ix::new(i), weight))
787    }
788}
789
790impl<N, Ix> ExactSizeIterator for NodeReferences<'_, N, Ix> where Ix: IndexType {}
791
792/// The adjacency matrix for **Csr** is a bitmap that's computed by
793/// `.adjacency_matrix()`.
794impl<N, E, Ty, Ix> GetAdjacencyMatrix for &Csr<N, E, Ty, Ix>
795where
796    Ix: IndexType,
797    Ty: EdgeType,
798{
799    type AdjMatrix = FixedBitSet;
800
801    fn adjacency_matrix(&self) -> FixedBitSet {
802        let n = self.node_count();
803        let mut matrix = FixedBitSet::with_capacity(n * n);
804        for edge in self.edge_references() {
805            let i = n * edge.source().index() + edge.target().index();
806            matrix.put(i);
807
808            if !self.is_directed() {
809                let j = edge.source().index() + n * edge.target().index();
810                matrix.put(j);
811            }
812        }
813        matrix
814    }
815
816    fn is_adjacent(&self, matrix: &FixedBitSet, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> bool {
817        let n = self.node_count();
818        let index = n * a.index() + b.index();
819        matrix.contains(index)
820    }
821}
822
/*
 * Example
 *
 * [ a 0 b
 *   c d e
 *   0 0 f ]
 *
 * Values: [a, b, c, d, e, f]
 * Column: [0, 2, 0, 1, 2, 2]   <- column (target node) of each value
 * Row   : [0, 2, 5, 6]         <- value index of each row start, followed by a
 *                                 final entry equal to the number of values
 */
836
#[cfg(test)]
mod tests {
    use super::Csr;
    use crate::algo::bellman_ford;
    use crate::algo::find_negative_cycle;
    use crate::algo::tarjan_scc;
    use crate::visit::Dfs;
    use crate::visit::VisitMap;
    use crate::Undirected;

    #[test]
    fn csr1() {
        let mut m: Csr = Csr::with_nodes(3);
        m.add_edge(0, 0, ());
        m.add_edge(1, 2, ());
        m.add_edge(2, 2, ());
        m.add_edge(0, 2, ());
        m.add_edge(1, 0, ());
        m.add_edge(1, 1, ());
        println!("{:?}", m);
        assert_eq!(&m.column, &[0, 2, 0, 1, 2, 2]);
        assert_eq!(&m.row, &[0, 2, 5, 6]);

        let added = m.add_edge(1, 2, ());
        assert!(!added);
        assert_eq!(&m.column, &[0, 2, 0, 1, 2, 2]);
        assert_eq!(&m.row, &[0, 2, 5, 6]);

        assert_eq!(m.neighbors_slice(1), &[0, 1, 2]);
        assert_eq!(m.node_count(), 3);
        assert_eq!(m.edge_count(), 6);
    }

    #[test]
    fn csr_undirected() {
        /*
           [ 1 . 1
             . . 1
             1 1 1 ]
        */

        let mut m: Csr<(), (), Undirected> = Csr::with_nodes(3);
        m.add_edge(0, 0, ());
        m.add_edge(0, 2, ());
        m.add_edge(1, 2, ());
        m.add_edge(2, 2, ());
        println!("{:?}", m);
        assert_eq!(&m.column, &[0, 2, 2, 0, 1, 2]);
        assert_eq!(&m.row, &[0, 2, 3, 6]);
        assert_eq!(m.node_count(), 3);
        assert_eq!(m.edge_count(), 4);
    }

    #[test]
    fn csr_from_error_1() {
        // not sorted in source: (0, 2) comes after (1, 0)
        let result: Result<Csr, _> = Csr::from_sorted_edges(&[(0, 1), (1, 0), (0, 2)]);
        let err = result.expect_err("edges out of order by source must be rejected");
        assert_eq!(err.first_error, (0, 2));
    }

    #[test]
    fn csr_from_error_2() {
        // not sorted in target: (1, 1) comes after (1, 2)
        let result: Result<Csr, _> = Csr::from_sorted_edges(&[(0, 1), (1, 0), (1, 2), (1, 1)]);
        let err = result.expect_err("edges out of order by target must be rejected");
        assert_eq!(err.first_error, (1, 1));
    }

    #[test]
    fn csr_from_error_duplicate() {
        // parallel edges are not allowed
        let result: Result<Csr, _> = Csr::from_sorted_edges(&[(0, 1), (0, 1)]);
        let err = result.expect_err("duplicate edges must be rejected");
        assert_eq!(err.first_error, (0, 1));
    }

    #[test]
    fn csr_from() {
        let m: Csr =
            Csr::from_sorted_edges(&[(0, 1), (0, 2), (1, 0), (1, 1), (2, 2), (2, 4)]).unwrap();
        println!("{:?}", m);
        assert_eq!(m.neighbors_slice(0), &[1, 2]);
        assert_eq!(m.neighbors_slice(1), &[0, 1]);
        assert_eq!(m.neighbors_slice(2), &[2, 4]);
        assert_eq!(m.node_count(), 5);
        assert_eq!(m.edge_count(), 6);
    }

    #[test]
    fn csr_dfs() {
        let mut m: Csr = Csr::from_sorted_edges(&[
            (0, 1),
            (0, 2),
            (1, 0),
            (1, 1),
            (1, 3),
            (2, 2),
            // disconnected subgraph
            (4, 4),
            (4, 5),
        ])
        .unwrap();
        println!("{:?}", m);
        let mut dfs = Dfs::new(&m, 0);
        while dfs.next(&m).is_some() {}
        for i in 0..m.node_count() - 2 {
            assert!(dfs.discovered.is_visited(&i), "visited {}", i)
        }
        assert!(!dfs.discovered[4]);
        assert!(!dfs.discovered[5]);

        m.add_edge(1, 4, ());
        println!("{:?}", m);

        dfs.reset(&m);
        dfs.move_to(0);
        while dfs.next(&m).is_some() {}

        for i in 0..m.node_count() {
            assert!(dfs.discovered[i], "visited {}", i)
        }
    }

    #[test]
    fn csr_tarjan() {
        let m: Csr = Csr::from_sorted_edges(&[
            (0, 1),
            (0, 2),
            (1, 0),
            (1, 1),
            (1, 3),
            (2, 2),
            (2, 4),
            (4, 4),
            (4, 5),
            (5, 2),
        ])
        .unwrap();
        println!("{:?}", m);
        let sccs = tarjan_scc(&m);
        println!("{:?}", sccs);

        // Normalize the order of nodes within and across components.
        let mut sccs: Vec<Vec<u32>> = sccs
            .into_iter()
            .map(|mut scc| {
                scc.sort_unstable();
                scc
            })
            .collect();
        sccs.sort();
        assert_eq!(sccs, vec![vec![0, 1], vec![2, 4, 5], vec![3]]);
    }

    #[test]
    fn test_bellman_ford() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 0.5),
            (0, 2, 2.),
            (1, 0, 1.),
            (1, 1, 1.),
            (1, 2, 1.),
            (1, 3, 1.),
            (2, 3, 3.),
            (4, 5, 1.),
            (5, 7, 2.),
            (6, 7, 1.),
            (7, 8, 3.),
        ])
        .unwrap();
        println!("{:?}", m);
        let result = bellman_ford(&m, 0).unwrap();
        println!("{:?}", result);
        let answer = [0., 0.5, 1.5, 1.5];
        assert_eq!(&answer, &result.distances[..4]);
        assert!(result.distances[4..].iter().all(|&x| f64::is_infinite(x)));
    }

    #[test]
    fn test_bellman_ford_neg_cycle() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 0.5),
            (0, 2, 2.),
            (1, 0, 1.),
            (1, 1, -1.),
            (1, 2, 1.),
            (1, 3, 1.),
            (2, 3, 3.),
        ])
        .unwrap();
        let result = bellman_ford(&m, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_find_neg_cycle1() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 0.5),
            (0, 2, 2.),
            (1, 0, 1.),
            (1, 1, -1.),
            (1, 2, 1.),
            (1, 3, 1.),
            (2, 3, 3.),
        ])
        .unwrap();
        let result = find_negative_cycle(&m, 0);
        assert_eq!(result, Some([1].to_vec()));
    }

    #[test]
    fn test_find_neg_cycle2() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 0.5),
            (0, 2, 2.),
            (1, 0, 1.),
            (1, 2, 1.),
            (1, 3, 1.),
            (2, 3, 3.),
        ])
        .unwrap();
        let result = find_negative_cycle(&m, 0);
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_neg_cycle3() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 1.),
            (0, 2, 1.),
            (0, 3, 1.),
            (1, 3, 1.),
            (2, 1, 1.),
            (3, 2, -3.),
        ])
        .unwrap();
        let result = find_negative_cycle(&m, 0);
        assert_eq!(result, Some([1, 3, 2].to_vec()));
    }

    #[test]
    fn test_find_neg_cycle4() {
        let m: Csr<(), _> = Csr::from_sorted_edges(&[(0, 0, -1.)]).unwrap();
        let result = find_negative_cycle(&m, 0);
        assert_eq!(result, Some([0].to_vec()));
    }

    #[test]
    fn test_edge_references() {
        use crate::visit::EdgeRef;
        use crate::visit::IntoEdgeReferences;
        let m: Csr<(), _> = Csr::from_sorted_edges(&[
            (0, 1, 0.5),
            (0, 2, 2.),
            (1, 0, 1.),
            (1, 1, 1.),
            (1, 2, 1.),
            (1, 3, 1.),
            (2, 3, 3.),
            (4, 5, 1.),
            (5, 7, 2.),
            (6, 7, 1.),
            (7, 8, 3.),
        ])
        .unwrap();
        let mut copy = Vec::new();
        for e in m.edge_references() {
            copy.push((e.source(), e.target(), *e.weight()));
            println!("{:?}", e);
        }
        let m2: Csr<(), _> = Csr::from_sorted_edges(&copy).unwrap();
        assert_eq!(&m.row, &m2.row);
        assert_eq!(&m.column, &m2.column);
        assert_eq!(&m.edges, &m2.edges);
    }

    #[test]
    fn test_add_node() {
        let mut g: Csr = Csr::new();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());

        assert!(g.add_edge(a, b, ()));
        assert!(g.add_edge(b, c, ()));
        assert!(g.add_edge(c, a, ()));

        println!("{:?}", g);

        assert_eq!(g.node_count(), 3);

        assert_eq!(g.neighbors_slice(a), &[b]);
        assert_eq!(g.neighbors_slice(b), &[c]);
        assert_eq!(g.neighbors_slice(c), &[a]);

        assert_eq!(g.edge_count(), 3);
    }

    #[test]
    fn test_add_node_with_existing_edges() {
        let mut g: Csr = Csr::new();
        let a = g.add_node(());
        let b = g.add_node(());

        assert!(g.add_edge(a, b, ()));

        let c = g.add_node(());

        println!("{:?}", g);

        assert_eq!(g.node_count(), 3);

        assert_eq!(g.neighbors_slice(a), &[b]);
        assert_eq!(g.neighbors_slice(b), &[]);
        assert_eq!(g.neighbors_slice(c), &[]);

        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_node_references() {
        use crate::visit::IntoNodeReferences;
        let mut g: Csr<u32> = Csr::new();
        g.add_node(42);
        g.add_node(3);
        g.add_node(44);

        let mut refs = g.node_references();
        assert_eq!(refs.next(), Some((0, &42)));
        assert_eq!(refs.next(), Some((1, &3)));
        assert_eq!(refs.next(), Some((2, &44)));
        assert_eq!(refs.next(), None);
    }
}