From f985995dbc55e212da881214d488dd2bc1fb0d8b Mon Sep 17 00:00:00 2001 From: goulart-paul Date: Wed, 27 Nov 2024 15:59:52 +0000 Subject: [PATCH] Utilities for converting to CSC canonicalization / deduplication. Fixes #132. --- src/algebra/csc/core.rs | 261 +++++++++++++++++++++++++++++++++++-- src/python/cscmatrix_py.rs | 10 +- 2 files changed, 256 insertions(+), 15 deletions(-) diff --git a/src/algebra/csc/core.rs b/src/algebra/csc/core.rs index 72ef724..0d44f2c 100644 --- a/src/algebra/csc/core.rs +++ b/src/algebra/csc/core.rs @@ -303,8 +303,102 @@ where Symmetric { src: self } } - /// Check that matrix data is correctly formatted. + /// Check that matrix data is canonically formatted. pub fn check_format(&self) -> Result<(), SparseFormatError> { + self.check_dimensions()?; + + //check for rowval monotonicity within each column + for col in 0..self.n { + let rng = self.colptr[col]..self.colptr[col + 1]; + if self.rowval[rng].windows(2).any(|c| c[0] >= c[1]) { + return Err(SparseFormatError::BadRowval); + } + } + //check for row values out of bounds + if !self.rowval.iter().all(|r| r < &self.m) { + return Err(SparseFormatError::BadRowval); + } + + Ok(()) + } + + /// Put matrix into standard ('canonical') form, operating in place. This function + /// sorts data within each column by row index, and removes any duplicates. + /// Does not remove structural zeros. + /// + /// # Panics + /// Panics if the matrix initial dimensions are incompatible. + /// + pub fn canonicalize(&mut self) -> Result<(), SparseFormatError> { + self.check_dimensions()?; + self.sort_indices()?; + self.deduplicate() + } + + /// Adds together repeated entries in the same column. Input must + /// already be in column sorted order. + fn sort_indices(&mut self) -> Result<(), SparseFormatError> { + let mut tempdata: Vec<(usize, T)> = Vec::new(); + + for col in 0..self.n { + let start = self.colptr[col]; + let stop = self.colptr[col + 1]; + + let nzval = &mut self.nzval[start..stop]; + let rowval = &mut self.rowval[start..stop]; + + tempdata.resize(stop - start, (0, T::zero())); + + for (i, (r, v)) in zip(rowval.iter(), nzval.iter()).enumerate() { + tempdata[i] = (*r, *v); + } + tempdata.sort_by_key(|&(r, _)| r); + + for (i, (r, v)) in tempdata.iter().enumerate() { + rowval[i] = *r; + nzval[i] = *v; + } + } + + Ok(()) + } + + /// Adds together repeated entries in the same column. Input must + /// already be in column sorted order. + fn deduplicate(&mut self) -> Result<(), SparseFormatError> { + let mut nnz = 0; + let mut stop = 0; + + for col in 0..self.n { + let mut ptr = stop; + stop = self.colptr[col + 1]; + + while ptr < stop { + let thisrow = self.rowval[ptr]; + let mut accum = self.nzval[ptr]; + ptr += 1; + + while (ptr < stop) && (self.rowval[ptr] == thisrow) { + accum = accum + self.nzval[ptr]; + ptr += 1; + } + self.rowval[nnz] = thisrow; + self.nzval[nnz] = accum; + nnz += 1; + } + self.colptr[col + 1] = nnz; + } + + self.rowval.truncate(nnz); + self.nzval.truncate(nnz); + + Ok(()) + } + + /// Check that for dimensional consistency. Private since users should + /// check everything via check_format, and the canonicalization functions + /// must at least check dimensions before running. + fn check_dimensions(&self) -> Result<(), SparseFormatError> { if self.rowval.len() != self.nzval.len() { return Err(SparseFormatError::IncompatibleDimension); } @@ -320,21 +414,9 @@ where if self.colptr.windows(2).any(|c| c[0] > c[1]) { return Err(SparseFormatError::BadColptr); } - - //check for rowval monotonicity within each column - for col in 0..self.n { - let rng = self.colptr[col]..self.colptr[col + 1]; - if self.rowval[rng].windows(2).any(|c| c[0] >= c[1]) { - return Err(SparseFormatError::BadRowval); - } - } - //check for row values out of bounds - if !self.rowval.iter().all(|r| r < &self.m) { - return Err(SparseFormatError::BadRowval); - } - Ok(()) } + /// True if matrices if the same size and sparsity pattern pub fn is_equal_sparsity(&self, other: &Self) -> bool { self.size() == other.size() && self.colptr == other.colptr && self.rowval == other.rowval @@ -807,3 +889,154 @@ fn test_drop_zeros() { assert_eq!(A, B); } + +#[test] +fn test_sort_indices() { + let mut A = CscMatrix { + m: 4, + n: 3, + colptr: vec![0, 2, 4, 5], + rowval: vec![3, 1, 4, 2, 2], + nzval: vec![2.0, 3.0, 1.0, 4.0, 5.0], + }; + + A.sort_indices().unwrap(); + assert_eq!(A.rowval, vec![1, 3, 2, 4, 2]); + assert_eq!(A.nzval, vec![3.0, 2.0, 4.0, 1.0, 5.0]); + + //nothing to sort + A.sort_indices().unwrap(); + assert_eq!(A.rowval, vec![1, 3, 2, 4, 2]); + assert_eq!(A.nzval, vec![3.0, 2.0, 4.0, 1.0, 5.0]); +} + +#[test] +fn test_sort_indices_with_duplicates() { + let mut A = CscMatrix { + m: 4, + n: 2, + colptr: vec![0, 3, 5], + rowval: vec![3, 3, 1, 2, 4], + nzval: vec![2.0, 3.0, 1.0, 1.0, 4.0], + }; + + A.sort_indices().unwrap(); + assert_eq!(A.rowval, vec![1, 3, 3, 2, 4]); + assert_eq!(A.nzval, vec![1.0, 2.0, 3.0, 1.0, 4.0]); +} + +#[test] +fn test_deduplicate() { + let mut A = CscMatrix { + m: 4, + n: 2, + colptr: vec![0, 2, 4], + rowval: vec![1, 1, 2, 4], + nzval: vec![3.0, 2.0, 1.0, 4.0], + }; + + A.deduplicate().unwrap(); + assert_eq!(A.colptr, vec![0, 1, 3]); + assert_eq!(A.rowval, vec![1, 2, 4]); + assert_eq!(A.nzval, vec![5.0, 1.0, 4.0]); + + // nothing to deduplicate + A.deduplicate().unwrap(); + assert_eq!(A.colptr, vec![0, 1, 3]); + assert_eq!(A.rowval, vec![1, 2, 4]); + assert_eq!(A.nzval, vec![5.0, 1.0, 4.0]); +} + +#[test] +fn test_deduplicate_multiple_columns() { + let mut A = CscMatrix { + m: 4, + n: 3, + colptr: vec![0, 2, 4, 6], + rowval: vec![1, 1, 2, 4, 3, 3], + nzval: vec![3.0, 2.0, 1.0, 4.0, 5.0, 6.0], + }; + + A.deduplicate().unwrap(); + assert_eq!(A.colptr, vec![0, 1, 3, 4]); + assert_eq!(A.rowval, vec![1, 2, 4, 3]); + assert_eq!(A.nzval, vec![5.0, 1.0, 4.0, 11.0]); +} + +#[test] +fn test_deduplicate_1col() { + let mut A = CscMatrix { + m: 4, + n: 1, + colptr: vec![0, 3], + rowval: vec![1, 1, 4], + nzval: vec![2.0, 3.0, 4.0], + }; + + A.deduplicate().unwrap(); + assert_eq!(A.colptr, vec![0, 2]); + assert_eq!(A.rowval, vec![1, 4]); + assert_eq!(A.nzval, vec![5.0, 4.0]); +} + +#[test] +fn test_canonicalize() { + let mut A = CscMatrix { + m: 4, + n: 3, + colptr: vec![0, 3, 4, 7], + rowval: vec![2, 1, 1, 4, 3, 4, 3], + nzval: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + }; + + A.canonicalize().unwrap(); + assert_eq!(A.colptr, vec![0, 2, 3, 5]); + assert_eq!(A.rowval, vec![1, 2, 4, 3, 4]); + assert_eq!(A.nzval, vec![5.0, 1.0, 4.0, 12.0, 6.0]); +} + +#[test] +fn test_canonicalize_structural_zeros() { + let mut A = CscMatrix { + m: 4, + n: 3, + colptr: vec![0, 3, 4, 7], + rowval: vec![2, 1, 1, 4, 3, 4, 3], + nzval: vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, -5.0], + }; + + A.canonicalize().unwrap(); + assert_eq!(A.colptr, vec![0, 2, 3, 5]); + assert_eq!(A.rowval, vec![1, 2, 4, 3, 4]); + assert_eq!(A.nzval, vec![5.0, 1.0, 0.0, 0.0, 6.0]); +} + +#[test] +fn test_canonicalize_empty() { + let mut A: CscMatrix = CscMatrix { + m: 0, + n: 0, + colptr: vec![0], + rowval: vec![], + nzval: vec![], + }; + + A.canonicalize().unwrap(); + assert!(A.rowval.is_empty()); + assert!(A.nzval.is_empty()); +} + +#[test] +fn test_canonicalize_singleton() { + let mut A = CscMatrix { + m: 4, + n: 1, + colptr: vec![0, 1], + rowval: vec![2], + nzval: vec![5.0], + }; + + A.sort_indices().unwrap(); + assert_eq!(A.rowval, vec![2]); + assert_eq!(A.nzval, vec![5.0]); +} diff --git a/src/python/cscmatrix_py.rs b/src/python/cscmatrix_py.rs index b3e9b95..a2aeed8 100644 --- a/src/python/cscmatrix_py.rs +++ b/src/python/cscmatrix_py.rs @@ -27,7 +27,15 @@ impl<'a> FromPyObject<'a> for PyCscMatrix { let colptr: Vec = obj.getattr("indptr")?.extract()?; let shape: Vec = obj.getattr("shape")?.extract()?; - let mat = CscMatrix::new(shape[0], shape[1], colptr, rowval, nzval); + let mut mat = CscMatrix::new(shape[0], shape[1], colptr, rowval, nzval); + + // if the python object was non in standard format, force the rust + // object to still be nicely formatted + let is_canonical: bool = obj.getattr("has_canonical_format")?.extract()?; + + if !is_canonical { + let _ = mat.canonicalize(); + } Ok(PyCscMatrix(mat)) }