From 7ed2b9fc506fbc9a5931b93c141349ecb82cd9dc Mon Sep 17 00:00:00 2001 From: "Kevin R. Thornton" Date: Wed, 9 Nov 2022 10:20:29 -0800 Subject: [PATCH] feat: column slice getters for tables (#404) --- src/_macros.rs | 32 +++++++++++ src/edge_table.rs | 25 ++++++++ src/individual_table.rs | 7 +++ src/migration_table.rs | 37 ++++++++++++ src/mutation_table.rs | 25 ++++++++ src/node_table.rs | 37 ++++++++++++ src/site_table.rs | 7 +++ tests/test_tables.rs | 122 +++++++++++++++++++++++++++++++++++++--- 8 files changed, 285 insertions(+), 7 deletions(-) diff --git a/src/_macros.rs b/src/_macros.rs index d6055e85..b6a0edfa 100644 --- a/src/_macros.rs +++ b/src/_macros.rs @@ -1159,6 +1159,38 @@ macro_rules! optional_container_comparison { }; } +macro_rules! build_table_column_slice_getter { + ($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => { + $(#[$attr])* + pub fn $name(&self) -> &[$cast] { + // Caveat: num_rows is u64 but we need usize + // The conversion is fallible but unlikely. + let num_rows = + usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed"); + let ptr = self.as_ref().$column as *const $cast; + // SAFETY: tables are initialzed, num rows comes + // from the C back end. + unsafe { std::slice::from_raw_parts(ptr, num_rows) } + } + }; +} + +macro_rules! build_table_column_slice_mut_getter { + ($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => { + $(#[$attr])* + pub fn $name(&mut self) -> &mut [$cast] { + // Caveat: num_rows is u64 but we need usize + // The conversion is fallible but unlikely. + let num_rows = + usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed"); + let ptr = self.as_ref().$column as *mut $cast; + // SAFETY: tables are initialzed, num rows comes + // from the C back end. + unsafe { std::slice::from_raw_parts_mut(ptr, num_rows) } + } + }; +} + #[cfg(test)] mod test { use crate::error::TskitError; diff --git a/src/edge_table.rs b/src/edge_table.rs index ca4a4e0e..178d4346 100644 --- a/src/edge_table.rs +++ b/src/edge_table.rs @@ -312,6 +312,31 @@ impl EdgeTable { }; Some(view) } + + build_table_column_slice_getter!( + /// Get the left column as a slice + => left, left_slice, Position); + build_table_column_slice_getter!( + /// Get the left column as a slice of [`f64`] + => left, left_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the right column as a slice + => right, right_slice, Position); + build_table_column_slice_getter!( + /// Get the left column as a slice of [`f64`] + => right, right_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the parent column as a slice + => parent, parent_slice, NodeId); + build_table_column_slice_getter!( + /// Get the parent column as a slice of [`crate::bindings::tsk_id_t`] + => parent, parent_slice_raw, ll_bindings::tsk_id_t); + build_table_column_slice_getter!( + /// Get the child column as a slice + => child, child_slice, NodeId); + build_table_column_slice_getter!( + /// Get the child column as a slice of [`crate::bindings::tsk_id_t`] + => child, child_slice_raw, ll_bindings::tsk_id_t); } build_owned_table_type!( diff --git a/src/individual_table.rs b/src/individual_table.rs index 9b140023..2bfdff17 100644 --- a/src/individual_table.rs +++ b/src/individual_table.rs @@ -444,6 +444,13 @@ match tables.individuals().metadata::(0.into()) }; Some(view) } + + build_table_column_slice_getter!( + /// Get the flags column as a slice + => flags, flags_slice, IndividualFlags); + build_table_column_slice_getter!( + /// Get the flags column as a slice + => flags, flags_slice_raw, ll_bindings::tsk_flags_t); } build_owned_table_type!( diff --git a/src/migration_table.rs b/src/migration_table.rs index 5aad4699..cc1d04b1 100644 --- a/src/migration_table.rs +++ b/src/migration_table.rs @@ -365,6 +365,43 @@ impl MigrationTable { }; Some(view) } + + build_table_column_slice_getter!( + /// Get the left column as a slice + => left, left_slice, Position); + build_table_column_slice_getter!( + /// Get the left column as a slice + => left, left_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the right column as a slice + => right, right_slice, Position); + build_table_column_slice_getter!( + /// Get the right column as a slice + => right, right_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice, Time); + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the node column as a slice + => node, node_slice, NodeId); + build_table_column_slice_getter!( + /// Get the node column as a slice + => node, node_slice_raw, ll_bindings::tsk_id_t); + build_table_column_slice_getter!( + /// Get the source column as a slice + => source, source_slice, PopulationId); + build_table_column_slice_getter!( + /// Get the source column as a slice + => source, source_slice_raw, ll_bindings::tsk_id_t); + build_table_column_slice_getter!( + /// Get the dest column as a slice + => dest, dest_slice, PopulationId); + build_table_column_slice_getter!( + /// Get the dest column as a slice + => dest, dest_slice_raw, ll_bindings::tsk_id_t); } build_owned_table_type!( diff --git a/src/mutation_table.rs b/src/mutation_table.rs index 896ae191..12f7060d 100644 --- a/src/mutation_table.rs +++ b/src/mutation_table.rs @@ -344,6 +344,31 @@ impl MutationTable { }; Some(view) } + + build_table_column_slice_getter!( + /// Get the node column as a slice + => node, node_slice, NodeId); + build_table_column_slice_getter!( + /// Get the node column as a slice + => node, node_slice_raw, crate::tsk_id_t); + build_table_column_slice_getter!( + /// Get the site column as a slice + => site, site_slice, SiteId); + build_table_column_slice_getter!( + /// Get the site column as a slice + => site, site_slice_raw, crate::tsk_id_t); + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice, Time); + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice_raw, f64); + build_table_column_slice_getter!( + /// Get the parent column as a slice + => parent, parent_slice, MutationId); + build_table_column_slice_getter!( + /// Get the parent column as a slice + => parent, parent_slice_raw, crate::tsk_id_t); } build_owned_table_type!( diff --git a/src/node_table.rs b/src/node_table.rs index d07eff8c..ccb25a18 100644 --- a/src/node_table.rs +++ b/src/node_table.rs @@ -520,6 +520,43 @@ impl NodeTable { .map(|row| row.id) .collect::>() } + + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice, Time); + build_table_column_slice_getter!( + /// Get the time column as a slice + => time, time_slice_raw, f64); + build_table_column_slice_mut_getter!( + /// Get the time column as a mutable slice + => time, time_slice_mut, Time); + build_table_column_slice_mut_getter!( + /// Get the time column as a mutable slice + => time, time_slice_raw_mut, f64); + build_table_column_slice_getter!( + /// Get the flags column as a slice + => flags, flags_slice, NodeFlags); + build_table_column_slice_getter!( + /// Get the flags column as a slice + => flags, flags_slice_raw, ll_bindings::tsk_flags_t); + build_table_column_slice_mut_getter!( + /// Get the flags column as a mutable slice + => flags, flags_slice_mut, NodeFlags); + build_table_column_slice_mut_getter!( + /// Get the flags column as a mutable slice + => flags, flags_slice_raw_mut, ll_bindings::tsk_flags_t); + build_table_column_slice_getter!( + /// Get the individual column as a slice + => individual, individual_slice, IndividualId); + build_table_column_slice_getter!( + /// Get the individual column as a slice + => individual, individual_slice_raw, crate::tsk_id_t); + build_table_column_slice_getter!( + /// Get the population column as a slice + => population, population_slice, PopulationId); + build_table_column_slice_getter!( + /// Get the population column as a slice + => population, population_slice_raw, crate::tsk_id_t); } build_owned_table_type!( diff --git a/src/site_table.rs b/src/site_table.rs index 36f3d638..03f648dc 100644 --- a/src/site_table.rs +++ b/src/site_table.rs @@ -263,6 +263,13 @@ impl SiteTable { }; Some(view) } + + build_table_column_slice_getter!( + /// Get the position column as a slice + => position, position_slice, Position); + build_table_column_slice_getter!( + /// Get the position column as a slice + => position, position_slice_raw, f64); } build_owned_table_type!( diff --git a/tests/test_tables.rs b/tests/test_tables.rs index 1bfea8c6..531b489e 100644 --- a/tests/test_tables.rs +++ b/tests/test_tables.rs @@ -71,10 +71,39 @@ mod test_adding_rows_without_metadata { Err(e) => panic!("Err from tables.{}: {:?}", stringify!(adder), e) } assert_eq!(tables.$table().iter().count(), 2); + tables } }}; } + macro_rules! compare_column_to_raw_column { + ($table: expr, $col: ident, $raw: ident) => { + assert_eq!( + $table.$col().len(), + usize::try_from($table.num_rows()).unwrap() + ); + assert_eq!( + $table.$raw().len(), + usize::try_from($table.num_rows()).unwrap() + ); + assert!($table + .$col() + .iter() + .zip($table.$raw().iter()) + .all(|(a, b)| a == b)) + }; + } + + macro_rules! compare_column_to_row { + ($table: expr, $col: ident, $target: ident) => { + assert!($table + .$col() + .iter() + .zip($table.iter()) + .all(|(c, r)| c == &r.$target)); + }; + } + // NOTE: all functions arguments for adding rows are Into // where T is one of our new types. // Further, functions taking multiple inputs of T are defined @@ -82,7 +111,18 @@ mod test_adding_rows_without_metadata { #[test] fn test_adding_edge() { - add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child + { + let tables = add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child + compare_column_to_raw_column!(tables.edges(), left_slice, left_slice_raw); + compare_column_to_raw_column!(tables.edges(), right_slice, right_slice_raw); + compare_column_to_raw_column!(tables.edges(), parent_slice, parent_slice_raw); + compare_column_to_raw_column!(tables.edges(), child_slice, child_slice_raw); + + compare_column_to_row!(tables.edges(), left_slice, left); + compare_column_to_row!(tables.edges(), right_slice, right); + compare_column_to_row!(tables.edges(), parent_slice, parent); + compare_column_to_row!(tables.edges(), child_slice, child); + } add_row_without_metadata!(edges, add_edge, tskit::Position::from(0.1), 0.5, 0, 1); // left, right, parent, child add_row_without_metadata!(edges, add_edge, 0.1, tskit::Position::from(0.5), 0, 1); // left, right, parent, child add_row_without_metadata!( @@ -105,8 +145,30 @@ mod test_adding_rows_without_metadata { #[test] fn test_adding_node() { - add_row_without_metadata!(nodes, add_node, 0, 0.1, -1, -1); // flags, time, population, - // individual + { + let tables = + add_row_without_metadata!(nodes, add_node, tskit::TSK_NODE_IS_SAMPLE, 0.1, -1, -1); // flags, time, population, + // individual + assert!(tables + .nodes() + .flags_slice() + .iter() + .zip(tables.nodes().flags_slice_raw().iter()) + .all(|(a, b)| a.bits() == *b)); + compare_column_to_raw_column!(tables.nodes(), time_slice, time_slice_raw); + compare_column_to_raw_column!(tables.nodes(), population_slice, population_slice_raw); + compare_column_to_raw_column!(tables.nodes(), individual_slice, individual_slice_raw); + + assert!(tables + .nodes() + .flags_slice() + .iter() + .zip(tables.nodes().iter()) + .all(|(c, r)| c == &r.flags)); + compare_column_to_row!(tables.nodes(), time_slice, time); + compare_column_to_row!(tables.nodes(), population_slice, population); + compare_column_to_row!(tables.nodes(), individual_slice, individual); + } add_row_without_metadata!( nodes, add_node, @@ -120,7 +182,11 @@ mod test_adding_rows_without_metadata { #[test] fn test_adding_site() { // No ancestral state - add_row_without_metadata!(sites, add_site, 2. / 3., None); + { + let tables = add_row_without_metadata!(sites, add_site, 2. / 3., None); + compare_column_to_raw_column!(tables.sites(), position_slice, position_slice_raw); + compare_column_to_row!(tables.sites(), position_slice, position); + } add_row_without_metadata!(sites, add_site, tskit::Position::from(2. / 3.), None); add_row_without_metadata!(sites, add_site, 2. / 3., Some(&[1_u8])); add_row_without_metadata!( @@ -136,14 +202,40 @@ mod test_adding_rows_without_metadata { // site, node, parent mutation, time, derived_state // Each value is a different Into so we skip doing // permutations - add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None); + { + let tables = add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None); + compare_column_to_raw_column!(tables.mutations(), node_slice, node_slice_raw); + compare_column_to_raw_column!(tables.mutations(), time_slice, time_slice_raw); + compare_column_to_raw_column!(tables.mutations(), site_slice, site_slice_raw); + compare_column_to_raw_column!(tables.mutations(), parent_slice, parent_slice_raw); + + compare_column_to_row!(tables.mutations(), node_slice, node); + compare_column_to_row!(tables.mutations(), time_slice, time); + compare_column_to_row!(tables.mutations(), site_slice, site); + compare_column_to_row!(tables.mutations(), parent_slice, parent); + } + add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, Some(&[23_u8])); } #[test] fn test_adding_individual() { // flags, location, parents - add_row_without_metadata!(individuals, add_individual, 0, None, None); + { + let tables = add_row_without_metadata!(individuals, add_individual, 0, None, None); + assert!(tables + .individuals() + .flags_slice() + .iter() + .zip(tables.individuals().flags_slice_raw().iter()) + .all(|(a, b)| a.bits() == *b)); + assert!(tables + .individuals() + .flags_slice() + .iter() + .zip(tables.individuals().iter()) + .all(|(c, r)| c == &r.flags)); + } add_row_without_metadata!( individuals, add_individual, @@ -179,7 +271,23 @@ mod test_adding_rows_without_metadata { fn test_adding_migration() { // migration table // (left, right), node, (source, dest), time - add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0); + { + let tables = + add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0); + compare_column_to_raw_column!(tables.migrations(), left_slice, left_slice_raw); + compare_column_to_raw_column!(tables.migrations(), right_slice, right_slice_raw); + compare_column_to_raw_column!(tables.migrations(), node_slice, node_slice_raw); + compare_column_to_raw_column!(tables.migrations(), time_slice, time_slice_raw); + compare_column_to_raw_column!(tables.migrations(), source_slice, source_slice_raw); + compare_column_to_raw_column!(tables.migrations(), dest_slice, dest_slice_raw); + + compare_column_to_row!(tables.migrations(), left_slice, left); + compare_column_to_row!(tables.migrations(), right_slice, right); + compare_column_to_row!(tables.migrations(), node_slice, node); + compare_column_to_row!(tables.migrations(), time_slice, time); + compare_column_to_row!(tables.migrations(), source_slice, source); + compare_column_to_row!(tables.migrations(), dest_slice, dest); + } add_row_without_metadata!( migrations, add_migration,