Skip to content

Commit

Permalink
feat: column slice getters for tables (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen authored Nov 9, 2022
1 parent aa94e1e commit 7ed2b9f
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 7 deletions.
32 changes: 32 additions & 0 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,38 @@ macro_rules! optional_container_comparison {
};
}

macro_rules! build_table_column_slice_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&self) -> &[$cast] {
// Caveat: num_rows is u64 but we need usize
// The conversion is fallible but unlikely.
let num_rows =
usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed");
let ptr = self.as_ref().$column as *const $cast;
// SAFETY: tables are initialzed, num rows comes
// from the C back end.
unsafe { std::slice::from_raw_parts(ptr, num_rows) }
}
};
}

macro_rules! build_table_column_slice_mut_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&mut self) -> &mut [$cast] {
// Caveat: num_rows is u64 but we need usize
// The conversion is fallible but unlikely.
let num_rows =
usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed");
let ptr = self.as_ref().$column as *mut $cast;
// SAFETY: tables are initialzed, num rows comes
// from the C back end.
unsafe { std::slice::from_raw_parts_mut(ptr, num_rows) }
}
};
}

#[cfg(test)]
mod test {
use crate::error::TskitError;
Expand Down
25 changes: 25 additions & 0 deletions src/edge_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,31 @@ impl EdgeTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice of [`f64`]
=> left, left_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice of [`f64`]
=> right, right_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice, NodeId);
build_table_column_slice_getter!(
/// Get the parent column as a slice of [`crate::bindings::tsk_id_t`]
=> parent, parent_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the child column as a slice
=> child, child_slice, NodeId);
build_table_column_slice_getter!(
/// Get the child column as a slice of [`crate::bindings::tsk_id_t`]
=> child, child_slice_raw, ll_bindings::tsk_id_t);
}

build_owned_table_type!(
Expand Down
7 changes: 7 additions & 0 deletions src/individual_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,13 @@ match tables.individuals().metadata::<MutationMetadata>(0.into())
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice, IndividualFlags);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice_raw, ll_bindings::tsk_flags_t);
}

build_owned_table_type!(
Expand Down
37 changes: 37 additions & 0 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,43 @@ impl MigrationTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice, Position);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice, NodeId);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the source column as a slice
=> source, source_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the source column as a slice
=> source, source_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the dest column as a slice
=> dest, dest_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the dest column as a slice
=> dest, dest_slice_raw, ll_bindings::tsk_id_t);
}

build_owned_table_type!(
Expand Down
25 changes: 25 additions & 0 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,31 @@ impl MutationTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice, NodeId);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the site column as a slice
=> site, site_slice, SiteId);
build_table_column_slice_getter!(
/// Get the site column as a slice
=> site, site_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice, MutationId);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice_raw, crate::tsk_id_t);
}

build_owned_table_type!(
Expand Down
37 changes: 37 additions & 0 deletions src/node_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,43 @@ impl NodeTable {
.map(|row| row.id)
.collect::<Vec<_>>()
}

build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_mut_getter!(
/// Get the time column as a mutable slice
=> time, time_slice_mut, Time);
build_table_column_slice_mut_getter!(
/// Get the time column as a mutable slice
=> time, time_slice_raw_mut, f64);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice, NodeFlags);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice_raw, ll_bindings::tsk_flags_t);
build_table_column_slice_mut_getter!(
/// Get the flags column as a mutable slice
=> flags, flags_slice_mut, NodeFlags);
build_table_column_slice_mut_getter!(
/// Get the flags column as a mutable slice
=> flags, flags_slice_raw_mut, ll_bindings::tsk_flags_t);
build_table_column_slice_getter!(
/// Get the individual column as a slice
=> individual, individual_slice, IndividualId);
build_table_column_slice_getter!(
/// Get the individual column as a slice
=> individual, individual_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the population column as a slice
=> population, population_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the population column as a slice
=> population, population_slice_raw, crate::tsk_id_t);
}

build_owned_table_type!(
Expand Down
7 changes: 7 additions & 0 deletions src/site_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,13 @@ impl SiteTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the position column as a slice
=> position, position_slice, Position);
build_table_column_slice_getter!(
/// Get the position column as a slice
=> position, position_slice_raw, f64);
}

build_owned_table_type!(
Expand Down
122 changes: 115 additions & 7 deletions tests/test_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,58 @@ mod test_adding_rows_without_metadata {
Err(e) => panic!("Err from tables.{}: {:?}", stringify!(adder), e)
}
assert_eq!(tables.$table().iter().count(), 2);
tables
}
}};
}

macro_rules! compare_column_to_raw_column {
($table: expr, $col: ident, $raw: ident) => {
assert_eq!(
$table.$col().len(),
usize::try_from($table.num_rows()).unwrap()
);
assert_eq!(
$table.$raw().len(),
usize::try_from($table.num_rows()).unwrap()
);
assert!($table
.$col()
.iter()
.zip($table.$raw().iter())
.all(|(a, b)| a == b))
};
}

macro_rules! compare_column_to_row {
($table: expr, $col: ident, $target: ident) => {
assert!($table
.$col()
.iter()
.zip($table.iter())
.all(|(c, r)| c == &r.$target));
};
}

// NOTE: all functions arguments for adding rows are Into<T>
// where T is one of our new types.
// Further, functions taking multiple inputs of T are defined
// as X: Into<T>, X2: Into<T>, etc., allowing mix-and-match.

#[test]
fn test_adding_edge() {
add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child
{
let tables = add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child
compare_column_to_raw_column!(tables.edges(), left_slice, left_slice_raw);
compare_column_to_raw_column!(tables.edges(), right_slice, right_slice_raw);
compare_column_to_raw_column!(tables.edges(), parent_slice, parent_slice_raw);
compare_column_to_raw_column!(tables.edges(), child_slice, child_slice_raw);

compare_column_to_row!(tables.edges(), left_slice, left);
compare_column_to_row!(tables.edges(), right_slice, right);
compare_column_to_row!(tables.edges(), parent_slice, parent);
compare_column_to_row!(tables.edges(), child_slice, child);
}
add_row_without_metadata!(edges, add_edge, tskit::Position::from(0.1), 0.5, 0, 1); // left, right, parent, child
add_row_without_metadata!(edges, add_edge, 0.1, tskit::Position::from(0.5), 0, 1); // left, right, parent, child
add_row_without_metadata!(
Expand All @@ -105,8 +145,30 @@ mod test_adding_rows_without_metadata {

#[test]
fn test_adding_node() {
add_row_without_metadata!(nodes, add_node, 0, 0.1, -1, -1); // flags, time, population,
// individual
{
let tables =
add_row_without_metadata!(nodes, add_node, tskit::TSK_NODE_IS_SAMPLE, 0.1, -1, -1); // flags, time, population,
// individual
assert!(tables
.nodes()
.flags_slice()
.iter()
.zip(tables.nodes().flags_slice_raw().iter())
.all(|(a, b)| a.bits() == *b));
compare_column_to_raw_column!(tables.nodes(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.nodes(), population_slice, population_slice_raw);
compare_column_to_raw_column!(tables.nodes(), individual_slice, individual_slice_raw);

assert!(tables
.nodes()
.flags_slice()
.iter()
.zip(tables.nodes().iter())
.all(|(c, r)| c == &r.flags));
compare_column_to_row!(tables.nodes(), time_slice, time);
compare_column_to_row!(tables.nodes(), population_slice, population);
compare_column_to_row!(tables.nodes(), individual_slice, individual);
}
add_row_without_metadata!(
nodes,
add_node,
Expand All @@ -120,7 +182,11 @@ mod test_adding_rows_without_metadata {
#[test]
fn test_adding_site() {
// No ancestral state
add_row_without_metadata!(sites, add_site, 2. / 3., None);
{
let tables = add_row_without_metadata!(sites, add_site, 2. / 3., None);
compare_column_to_raw_column!(tables.sites(), position_slice, position_slice_raw);
compare_column_to_row!(tables.sites(), position_slice, position);
}
add_row_without_metadata!(sites, add_site, tskit::Position::from(2. / 3.), None);
add_row_without_metadata!(sites, add_site, 2. / 3., Some(&[1_u8]));
add_row_without_metadata!(
Expand All @@ -136,14 +202,40 @@ mod test_adding_rows_without_metadata {
// site, node, parent mutation, time, derived_state
// Each value is a different Into<T> so we skip doing
// permutations
add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None);
{
let tables = add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None);
compare_column_to_raw_column!(tables.mutations(), node_slice, node_slice_raw);
compare_column_to_raw_column!(tables.mutations(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.mutations(), site_slice, site_slice_raw);
compare_column_to_raw_column!(tables.mutations(), parent_slice, parent_slice_raw);

compare_column_to_row!(tables.mutations(), node_slice, node);
compare_column_to_row!(tables.mutations(), time_slice, time);
compare_column_to_row!(tables.mutations(), site_slice, site);
compare_column_to_row!(tables.mutations(), parent_slice, parent);
}

add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, Some(&[23_u8]));
}

#[test]
fn test_adding_individual() {
// flags, location, parents
add_row_without_metadata!(individuals, add_individual, 0, None, None);
{
let tables = add_row_without_metadata!(individuals, add_individual, 0, None, None);
assert!(tables
.individuals()
.flags_slice()
.iter()
.zip(tables.individuals().flags_slice_raw().iter())
.all(|(a, b)| a.bits() == *b));
assert!(tables
.individuals()
.flags_slice()
.iter()
.zip(tables.individuals().iter())
.all(|(c, r)| c == &r.flags));
}
add_row_without_metadata!(
individuals,
add_individual,
Expand Down Expand Up @@ -179,7 +271,23 @@ mod test_adding_rows_without_metadata {
fn test_adding_migration() {
// migration table
// (left, right), node, (source, dest), time
add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0);
{
let tables =
add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0);
compare_column_to_raw_column!(tables.migrations(), left_slice, left_slice_raw);
compare_column_to_raw_column!(tables.migrations(), right_slice, right_slice_raw);
compare_column_to_raw_column!(tables.migrations(), node_slice, node_slice_raw);
compare_column_to_raw_column!(tables.migrations(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.migrations(), source_slice, source_slice_raw);
compare_column_to_raw_column!(tables.migrations(), dest_slice, dest_slice_raw);

compare_column_to_row!(tables.migrations(), left_slice, left);
compare_column_to_row!(tables.migrations(), right_slice, right);
compare_column_to_row!(tables.migrations(), node_slice, node);
compare_column_to_row!(tables.migrations(), time_slice, time);
compare_column_to_row!(tables.migrations(), source_slice, source);
compare_column_to_row!(tables.migrations(), dest_slice, dest);
}
add_row_without_metadata!(
migrations,
add_migration,
Expand Down

0 comments on commit 7ed2b9f

Please sign in to comment.