Skip to content

Commit

Permalink
Merge pull request #46 from molpopgen/add_important_missing_functions
Browse files Browse the repository at this point in the history
Add kc_distance and num_tracked_samples to Tree
  • Loading branch information
molpopgen authored Apr 12, 2021
2 parents f559825 + fb85ba9 commit f1887d5
Showing 1 changed file with 146 additions and 1 deletion.
147 changes: 146 additions & 1 deletion src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,22 @@ impl Tree {

fn new(ts: &TreeSequence, flags: TreeFlags) -> Result<Self, TskitError> {
let mut tree = Self::wrap(ts.consumed.nodes().num_rows(), flags);
let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
let mut rv =
unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
if rv < 0 {
return Err(TskitError::ErrorCode { code: rv });
}
// Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
rv = unsafe {
ll_bindings::tsk_tree_set_tracked_samples(
tree.as_mut_ptr(),
ts.num_samples() as u64,
tree.inner.samples,
)
};
}

handle_tsk_return_value!(rv, tree)
}

Expand Down Expand Up @@ -355,6 +370,39 @@ impl Tree {
false => Ok(b),
}
}

/// Get the number of samples below node `u`.
///
/// # Errors
///
/// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
pub fn num_tracked_samples(&self, u: tsk_id_t) -> Result<u64, TskitError> {
let mut n = u64::MAX;
let np: *mut u64 = &mut n;
let code = unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u, np) };
handle_tsk_return_value!(code, n)
}

/// Calculate the average Kendall-Colijn (`K-C`) distance between
/// pairs of trees whose intervals overlap.
///
/// # Note
///
/// * [Citation](https://doi.org/10.1093/molbev/msw124)
///
/// # Parameters
///
/// * `lambda` specifies the relative weight of topology and branch length.
/// If `lambda` is 0, we only consider topology.
/// If `lambda` is 1, we only consider branch lengths.
pub fn kc_distance(&self, other: &Tree, lambda: f64) -> Result<f64, TskitError> {
let mut kc = f64::NAN;
let kcp: *mut f64 = &mut kc;
let code = unsafe {
ll_bindings::tsk_tree_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
};
handle_tsk_return_value!(code, kc)
}
}

impl streaming_iterator::StreamingIterator for Tree {
Expand Down Expand Up @@ -763,6 +811,7 @@ impl TreeSequence {
/// # Parameters
///
/// * `lambda` specifies the relative weight of topology and branch length.
/// See [`Tree::kc_distance`] for more details.
pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
let mut kc: f64 = f64::NAN;
let kcp: *mut f64 = &mut kc;
Expand All @@ -771,6 +820,11 @@ impl TreeSequence {
};
handle_tsk_return_value!(code, kc)
}

// FIXME: document
pub fn num_samples(&self) -> tsk_size_t {
unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }
}
}

#[cfg(test)]
Expand Down Expand Up @@ -799,6 +853,51 @@ mod test_trees {
tables.tree_sequence().unwrap()
}

fn make_small_table_collection_two_trees() -> TableCollection {
// The two trees are:
// 0
// +++
// | | 1
// | | +++
// 2 3 4 5

// 0
// +-+-+
// 1 |
// +-+-+ |
// 2 4 5 3

let mut tables = TableCollection::new(1000.).unwrap();
tables.add_node(0, 2.0, TSK_NULL, TSK_NULL).unwrap();
tables.add_node(0, 1.0, TSK_NULL, TSK_NULL).unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables.add_edge(500., 1000., 0, 1).unwrap();
tables.add_edge(0., 500., 0, 2).unwrap();
tables.add_edge(0., 1000., 0, 3).unwrap();
tables.add_edge(500., 1000., 1, 2).unwrap();
tables.add_edge(0., 1000., 1, 4).unwrap();
tables.add_edge(0., 1000., 1, 5).unwrap();
tables.full_sort().unwrap();
tables.build_index(0).unwrap();
tables
}

fn treeseq_from_small_table_collection_two_trees() -> TreeSequence {
let tables = make_small_table_collection_two_trees();
tables.tree_sequence().unwrap()
}

#[test]
fn test_create_treeseq_new_from_tables() {
let tables = make_small_table_collection();
Expand Down Expand Up @@ -877,18 +976,46 @@ mod test_trees {
}
}

#[test]
fn test_num_tracked_samples() {
let treeseq = treeseq_from_small_table_collection();
assert_eq!(treeseq.inner.num_samples, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
}
}

#[should_panic]
#[test]
fn test_num_tracked_samples_not_tracking_samples() {
let treeseq = treeseq_from_small_table_collection();
assert_eq!(treeseq.inner.num_samples, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
}
}

#[test]
fn test_iterate_samples() {
let tables = make_small_table_collection();
let treeseq = tables.tree_sequence().unwrap();

let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
if let Some(tree) = tree_iter.next() {
assert!(!tree.flags.contains(TreeFlags::NO_SAMPLE_COUNTS));
assert!(tree.flags.contains(TreeFlags::SAMPLE_LISTS));
let mut s = vec![];
for i in tree.samples(0).unwrap() {
s.push(i);
}
assert_eq!(s.len(), 2);
assert_eq!(s.len(), tree.num_tracked_samples(0).unwrap() as usize);
assert_eq!(s[0], 1);
assert_eq!(s[1], 2);

Expand All @@ -899,12 +1026,30 @@ mod test_trees {
}
assert_eq!(s.len(), 1);
assert_eq!(s[0], u);
assert_eq!(s.len(), tree.num_tracked_samples(u).unwrap() as usize);
}
} else {
panic!("Expected a tree");
}
}

#[test]
fn test_iterate_samples_two_trees() {
let treeseq = treeseq_from_small_table_collection_two_trees();
assert_eq!(treeseq.inner.num_trees, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
while let Some(tree) = tree_iter.next() {
for n in tree.nodes(NodeTraversalOrder::Preorder) {
let mut nsamples = 0;
for _ in tree.samples(n).unwrap() {
nsamples += 1;
}
assert!(nsamples > 0);
assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
}
}
}

#[test]
fn test_kc_distance_naive_test() {
let ts1 = treeseq_from_small_table_collection();
Expand Down

0 comments on commit f1887d5

Please sign in to comment.