Poisson context likelihood and better CLI flexibility #126

Merged · 8 commits · Apr 8, 2024
Changes from 7 commits
298 changes: 177 additions & 121 deletions gctree/branching_processes.py

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions gctree/cli.py
@@ -209,6 +209,8 @@ def isotype_add(forest):
mutability_file=args.mutability,
substitution_file=args.substitution,
chain_split=args.chain_split,
branching_process_ranking_coeff=args.branching_process_ranking_coeff,
use_old_mut_parsimony=args.use_old_mut_parsimony,
)

if args.verbose:
@@ -610,6 +612,16 @@ def get_parser():
"See a file excerpt in the documentation for :meth:`mutation_model.MutationModel`."
),
)
parser_infer.add_argument(
"--branching_process_ranking_coeff",
type=float,
default=-1,
help=(
"Coefficient used for branching process likelihood, when ranking trees by a linear "
"combination of traits. This value will be ignored if `--ranking_coeffs` argument is not "
"also provided."
),
)
parser_infer.add_argument(
"--ranking_coeffs",
type=float,
Expand All @@ -623,6 +635,16 @@ def get_parser():
"isotype parsimony, and mutability parsimony in that order."
),
)
parser_infer.add_argument(
"--use_old_mut_parsimony",
action="store_true",
help=(
"Use old mutability parsimony instead of poisson context likelihood. Not recommended "
"unless attempting to reproduce results from older versions of gctree. "
"This argument will have no effect unless an S5F model is provided with the arguments "
"`--mutability` and `--substitution`."
),
)
parser_infer.add_argument(
"--summarize_forest",
action="store_true",
19 changes: 11 additions & 8 deletions gctree/isotyping.py
@@ -405,7 +405,7 @@ def explode_idmap(
return newidmap


def _isotype_dagfuncs() -> hdag.utils.AddFuncDict:
def _isotype_dagfuncs() -> hdag.utils.HistoryDagFilter:
"""Return functions for filtering by isotype parsimony score on the history
DAG.

@@ -435,13 +435,16 @@ def edge_weight_func(n1: hdag.HistoryDagNode, n2: hdag.HistoryDagNode):
n1iso = list(n1isos.keys())[0]
return int(sum(isotype_distance(n1iso, n2iso) for n2iso in n2isos.keys()))

return hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": edge_weight_func,
"accum_func": sum,
},
name="Isotype Pars.",
return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": edge_weight_func,
"accum_func": sum,
},
name="Isotype Pars.",
),
min,
)


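Note on the `AddFuncDict` → `HistoryDagFilter` change: the filter now carries its own optimality function (`min`), so callers unpack one object instead of passing `optimal_func` separately. A minimal sketch of the new calling convention, mirroring the `tests/test_isotype.py` changes below (`tdag`, a history DAG, is assumed to already exist):

```python
from gctree.isotyping import _isotype_dagfuncs

dag_filter = _isotype_dagfuncs()  # HistoryDagFilter wrapping the weight funcs and `min`

# Distribution of isotype parsimony scores over histories in the DAG:
counts = tdag.weight_count(**dag_filter)

# Trim to minimum-score histories; no explicit optimal_func=min needed:
tdag.trim_optimal_weight(**dag_filter)
```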
142 changes: 102 additions & 40 deletions gctree/mutation_model.py
@@ -10,6 +10,8 @@
import historydag as hdag
from multiset import FrozenMultiset
from typing import Tuple, List, Callable, Optional
import itertools
import math


class MutationModel:
@@ -129,20 +131,25 @@ def mutability(self, kmer: str) -> Tuple[np.float64, np.float64]:
"sequence {} must contain only characters A, C, G, T, or N".format(kmer)
)

mutabilities_to_average, substitutions_to_average = zip(
*[self.context_model[x] for x in MutationModel._disambiguate(kmer)]
)

average_mutability = np.mean(mutabilities_to_average)
average_substitution = {
b: sum(
substitution_dict[b] for substitution_dict in substitutions_to_average
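# Memoized lookup: averaging over every disambiguation of an
# N-containing k-mer is expensive, so the result is computed once and
# stored back in self.context_model for reuse.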
cached = self.context_model.get(kmer, None)
if cached is None:
mutabilities_to_average, substitutions_to_average = zip(
*[self.context_model[x] for x in MutationModel._disambiguate(kmer)]
)
/ len(substitutions_to_average)
for b in "ACGT"
}

return average_mutability, average_substitution
average_mutability = np.mean(mutabilities_to_average)
average_substitution = {
b: sum(
substitution_dict[b]
for substitution_dict in substitutions_to_average
)
/ len(substitutions_to_average)
for b in "ACGT"
}
cached = average_mutability, average_substitution
self.context_model[kmer] = cached

return cached

def mutabilities(self, sequence: str) -> List[Tuple[np.float64, np.float64]]:
r"""Returns the mutability of a sequence at each site, along with
@@ -440,7 +447,7 @@ def _sequence_disambiguations(sequence, _accum=""):

def _mutability_dagfuncs(
*args, splits: List[int] = [], **kwargs
) -> hdag.utils.AddFuncDict:
) -> hdag.utils.HistoryDagFilter:
"""Return functions for counting mutability parsimony on the history DAG.

Mutability parsimony of a tree is the sum over all edges in the tree
@@ -478,36 +485,38 @@ def distance(node1, node2):
else:
return dist(node1.label.sequence, node2.label.sequence)

return hdag.utils.AddFuncDict(
{"start_func": lambda n: 0, "edge_weight_func": distance, "accum_func": sum},
name="Mut. Pars.",
return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": distance,
"accum_func": sum,
},
name="Mut. Pars.",
),
min,
)


def _mutability_distance_precursors(
mutation_model: MutationModel, splits: List[int] = []
):
chunk_idxs = list(zip([0] + splits, splits + [None]))
# Caching could be moved to the MutationModel class instead.
context_model = mutation_model.context_model.copy()
k = mutation_model.k
h = k // 2
# Build all sequences with (when k=5) one or two Ns on either end
templates = [
("N" * left, "N" * (k - left - right), "N" * right)
for left in range(h + 1)
for right in range(h + 1)
if left != 0 or right != 0
]

kmers_to_compute = [
leftns + stub + rightns
for leftns, ambig_stub, rightns in templates
for stub in _sequence_disambiguations(ambig_stub)
]
# Cache all these mutabilities in context_model also
context_model.update(
{kmer: mutation_model.mutability(kmer) for kmer in kmers_to_compute}

h = mutation_model.k // 2

# Pads the sequence with N's, including at the chain-split boundary, to
# avoid unrelated sites being treated as part of each other's context.

# Indices at which padding N's appear in sequences returned from add_ns.
# Does not include the indices of the last two N's.
padding_indices = set(
itertools.chain.from_iterable(
[
range(split + idx * h, split + (idx + 1) * h)
for idx, split in enumerate([0] + splits)
]
)
)
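# Worked example (k=5, so h=2), inferred from the comment above and the
# five-mers exercised in tests/test_likelihoods.py: for splits=[6] and a
# 12-site parent, add_ns pads to "NN" + seq[:6] + "NN" + seq[6:] + "NN",
# so padding_indices == {0, 1, 8, 9} (the trailing two N's are excluded).
# Each real site then keeps a centered 5-mer context that never spans
# the chain-split boundary.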

def add_ns(seq: str):
@@ -535,16 +544,26 @@ def sum_minus_logp(pairs: FrozenMultiset):
p_arr = [
mult
* (
np.log(context_model[mer][0])
+ np.log(context_model[mer][1][newbase])
np.log(mutation_model.mutability(mer)[0])
+ np.log(mutation_model.mutability(mer)[1][newbase])
)
for (mer, newbase), mult in pairs
]
return -sum(p_arr)
else:
return 0.0

return (mutpairs, sum_minus_logp)
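# Total mutability of the parent sequence: the sum of per-site
# mutabilities over all non-padding positions. Used as the aggregate
# rate term (log(mut_sum)) in _context_poisson_likelihood below.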
def mutability_sum(parent_seq):
padded_seq = add_ns(parent_seq)
for idx in padding_indices:
assert padded_seq[idx] == "N"
return sum(
mutation_model.mutability(padded_seq[idx - h : idx + h + 1])[0]
for idx, _ in enumerate(padded_seq[:-h])
if idx not in padding_indices
)

return (mutpairs, sum_minus_logp, mutability_sum)


def _mutability_distance(mutation_model: MutationModel, splits=[]):
@@ -562,11 +581,54 @@ def _mutability_distance(mutation_model: MutationModel, splits=[]):

Note that, in particular, this function is not symmetric on its arguments.
"""
mutpairs, sum_minus_logp = _mutability_distance_precursors(
mutpairs, sum_minus_logp, _ = _mutability_distance_precursors(
mutation_model, splits=splits
)

def distance(seq1, seq2):
return sum_minus_logp(mutpairs(seq1, seq2))

return distance


def _context_poisson_likelihood(mutation_model: MutationModel, splits=[]):
mutpairs, sum_minus_logp, mutability_sum = _mutability_distance_precursors(
mutation_model, splits=splits
)

def distance(seq1, seq2):
subs = mutpairs(seq1, seq2)
sub_count = len(subs)
if sub_count == 0:
return 0
else:
mut_sum = mutability_sum(seq1)
substitution_sum = -sum_minus_logp(subs)
return (
substitution_sum
+ (sub_count * (math.log(sub_count) - math.log(mut_sum)))
- sub_count
)

return distance


def _context_poisson_likelihood_dagfuncs(*args, splits: List[int] = [], **kwargs):
mutation_model = MutationModel(*args, **kwargs)
distance = _context_poisson_likelihood(mutation_model, splits=splits)

return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": lambda n1, n2: (
0
if n1.is_ua_node()
else distance(n1.label.sequence, n2.label.sequence)
),
"accum_func": sum,
},
name="LogContextLikelihood",
),
max,
)
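For reference, one way to read the quantity returned by `distance` in `_context_poisson_likelihood` above (a derivation sketch inferred from the code and from `test_context_likelihood` below, not stated explicitly in the PR): suppose mutations arrive at padded site $i$ as a Poisson process with rate $\lambda m_i$, where $m_i$ is the site's mutability and $M = \sum_i m_i$ is `mutability_sum`, and each mutation substitutes to base $b$ with probability $s_i(b)$. With $n$ observed substitutions, maximizing over the branch-length parameter gives $\hat\lambda = n/M$, and

$$
\log L = \sum_{(i,\,b)\in\text{subs}} \log\bigl(m_i\, s_i(b)\bigr) + n\bigl(\log n - \log M\bigr) - n,
$$

which matches `substitution_sum + sub_count * (log(sub_count) - log(mut_sum)) - sub_count`; the `HistoryDagFilter` then ranks trees by `max` of this log likelihood.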
8 changes: 8 additions & 0 deletions tests/smalltest.sh
@@ -7,7 +7,15 @@ export MPLBACKEND=agg
mkdir -p tests/smalltest_output
wget -O HS5F_Mutability.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Mutability.csv
wget -O HS5F_Substitution.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Substitution.csv

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs 1 1 0 --use_old_mut_parsimony --branching_process_ranking_coeff 0

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs .01 -1 0 --branching_process_ranking_coeff -1 --summarize_forest --tree_stats

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv
8 changes: 4 additions & 4 deletions tests/test_isotype.py
@@ -51,9 +51,9 @@ def test_trim_byisotype():
for node in tdag.preorder():
if node.attr is not None:
node.attr["isotype"] = node._dp_data
kwargs = _isotype_dagfuncs()
c = tdag.weight_count(**kwargs)
dag_filter = _isotype_dagfuncs()
c = tdag.weight_count(**dag_filter)
key = min(c)
count = c[key]
tdag.trim_optimal_weight(**kwargs, optimal_func=min)
assert tdag.weight_count(**kwargs) == {key: count}
tdag.trim_optimal_weight(**dag_filter)
assert tdag.weight_count(**dag_filter) == {key: count}
36 changes: 36 additions & 0 deletions tests/test_likelihoods.py
@@ -1,6 +1,8 @@
import gctree.branching_processes as bp
import gctree.phylip_parse as pp
import gctree.utils as utils
import gctree.mutation_model as mm
from math import log

import numpy as np
from multiset import FrozenMultiset
@@ -198,3 +200,37 @@ def test_recursion_depth():
bp.CollapsedTree._max_ll_cache = {}
with np.errstate(all="raise"):
bp.CollapsedTree._ll_genotype(2, 500, 0.4, 0.6)


def test_context_likelihood():
# These files will be present if pytest is run through `make test`.
mutation_model = mm.MutationModel(
mutability_file="HS5F_Mutability.csv", substitution_file="HS5F_Substitution.csv"
)
log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[])

parent_seq = "AAGAAA"
child_seq = "AATCAA"

term1 = sum(
log(
mutation_model.mutability(fivemer)[0]
* mutation_model.mutability(fivemer)[1][target_base]
)
for fivemer, target_base in [("AAGAA", "T"), ("AGAAA", "C")]
)
sum_mutabilities = sum(
mutation_model.mutability(fivemer)[0]
for fivemer in ["NNAAG", "NAAGA", "AAGAA", "AGAAA", "GAAAN", "AAANN"]
)
true_val = term1 + 2 * log(2 / sum_mutabilities) - 2
assert true_val == log_likelihood(parent_seq, child_seq)

# Now test chain split:
parent_seq = parent_seq + parent_seq
child_seq = child_seq + child_seq
# At index 6, the second concatenated sequence starts.
log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[6])

true_val = 2 * term1 + 4 * log(4 / (2 * sum_mutabilities)) - 4
assert true_val == log_likelihood(parent_seq, child_seq)