Skip to content

Commit 702d361

Browse files
committed
Remove rule three; new reduction (dijkstra)
1 parent 2b1b106 commit 702d361

26 files changed

+445
-330
lines changed

.DS_Store

-6 KB
Binary file not shown.

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
ld_graph/_version.py
22
*.asv
3-
*.asv
3+
.DS_Store
4+
ld_graph/__pycache__

examples/estimate_precision_matrix.m

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
addpath(genpath('../precision'))
2+
path_prefix='~/Dropbox/Pouria/data/';
3+
4+
% Genotype matrix
5+
X=load([path_prefix, 'genomat'])';
6+
7+
% Weighted adjacency matrix
8+
import_weighted = 1;
9+
A_weighted=importGraph([path_prefix,'adjlist'], import_weighted);
10+
11+
% Empty rows/columns correspond to duplicate SNPs (on same brick as
12+
% another SNP)
13+
SNPs = find(any(A_weighted));
14+
X = X(:,SNPs);
15+
if any(X(:)==-1)
16+
error('Missing genotypes not supported')
17+
end
18+
A_weighted = A_weighted(SNPs,SNPs);
19+
[numHaplotypes, numSNPs] = size(X);
20+
allele_freq = mean(X);
21+
22+
% How many edges to retain
23+
desired_density = 0.2;
24+
25+
% Threshold weighted network to desired density, and add self-edges
26+
max_density = nnz(A_weighted) / length(A_weighted)^2;
27+
threshold = quantile(nonzeros(A_weighted),...
28+
max(0, 1 - desired_density / max_density));
29+
A = A_weighted + speye(numSNPs) > threshold;
30+
31+
% LD matrix for edges of A
32+
[ii,jj] = find(A);
33+
X = (X - mean(X));
34+
X = X./sqrt(mean(X.^2));
35+
R = arrayfun(@(i,j)dot(X(:,i),X(:,j)),ii,jj)/numHaplotypes;
36+
R = sparse(ii,jj,R);
37+
38+
39+
% estimated precision matrix
40+
maxReps = 1e3;
41+
convergenceTol=1e-6;
42+
precisionEstimate = speye(size(A));
43+
44+
[precisionEstimate, obj_val] = LDPrecision(R,'P0',precisionEstimate,'max_steps',maxReps,...
45+
'convergence_tol',convergenceTol,'printstuff',true);
46+
47+
% Plot stuff for SNPs with frequency above threshold
48+
AF_threshold = 0.05;
49+
plotting_script
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
21
common = allele_freq>AF_threshold;
3-
Rr=inv(omegaEst);
2+
Rr=inv(precisionEstimate);
43
Rrc=Rr(common,common);
5-
Rc=corr(X(:,common));
4+
Rc = corr(X(:,common));
65

76
MSE = mean((Rrc(:)-Rc(:)).^2) / mean((Rc(:)).^2);
87
fprintf('Percent mean squared difference: %f\n', MSE)
@@ -14,5 +13,5 @@
1413
common=find(common,200,'first');
1514
empty=repmat({''},1,length(common));
1615
subplot(2,2,3);imagesc(Rr(common,common));colormap(bluewhitered(256));caxis([-1 1]);set(gca,'XTick',[],'YTick',[]);title('Regularized covariance')
17-
subplot(2,2,4);imagesc(corr(X(:,common)));colormap(bluewhitered(256));caxis([-1 1]);set(gca,'XTick',[],'YTick',[]);title('Sample covariance')
16+
subplot(2,2,4);imagesc(Rc);colormap(bluewhitered(256));caxis([-1 1]);set(gca,'XTick',[],'YTick',[]);title('Sample covariance')
1817
subplot(2,2,2);imagesc(A(common,common)+0);colormap(bluewhitered(256));caxis([-1 1]);set(gca,'XTick',[],'YTick',[]);title('Graphical model')
-220 Bytes
Binary file not shown.
-2.39 KB
Binary file not shown.
-5.14 KB
Binary file not shown.
-1.65 KB
Binary file not shown.
-410 Bytes
Binary file not shown.
Binary file not shown.
-2.17 KB
Binary file not shown.
-738 Bytes
Binary file not shown.

ld_graph/bricks.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import tskit
55
from tqdm import tqdm
66

7+
from . import utility
8+
79

810
class Bricks:
9-
def __init__(
10-
self,
11-
ts,
12-
):
11+
def __init__(self, ts, add_dummy_bricks=True):
1312
self.ts = ts
13+
self.add_dummy_bricks = add_dummy_bricks
1414

1515
def bifurcate_edge(self, edge_child, interval, tables, current_edges):
1616
"""
@@ -61,7 +61,7 @@ def naive_split_edges(self, mode="leaf"):
6161
for tree, (interval, edges_out, edges_in) in tqdm(
6262
zip(trees, edge_diffs),
6363
desc="Brick tree sequence: iterate over edges",
64-
total=ts.num_trees,
64+
total=ts.num_trees - 1,
6565
):
6666
# Add edges coming out to new edge table
6767
for edge in edges_out:
@@ -122,4 +122,6 @@ def naive_split_edges(self, mode="leaf"):
122122
tables.sort()
123123

124124
new_ts = tables.tree_sequence()
125+
if self.add_dummy_bricks:
126+
new_ts = utility.add_dummy_bricks(new_ts)
125127
return new_ts

ld_graph/bricks_graph.py

+72-119
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,56 @@
55

66
import networkx as nx
77
import numpy as np
8-
import pandas as pd
98
from tqdm import tqdm
109

1110
from . import utility
1211

1312

1413
class BrickGraph:
15-
def __init__(self, bricked_ts, use_rule_two=True):
14+
def __init__(self, bricked_ts, threshold):
1615
self.bricked_ts = bricked_ts
17-
self.from_to_set = set()
18-
self.use_rule_two = use_rule_two
16+
self.threshold = threshold
17+
self.brick_graph = nx.DiGraph()
18+
self.freqs = utility.get_brick_frequencies(self.bricked_ts)
1919

20-
# make an argument for in and out here, so we know how to split if it's labeled
21-
def up_vertex(self, edge_id, in_out):
22-
if edge_id in self.labeled_bricks:
23-
if in_out == "in":
24-
return 6 * edge_id + 4
25-
else:
26-
return 6 * edge_id + 5
27-
else:
28-
return 6 * edge_id
20+
def find_odds(self, brick):
21+
return self.freqs[brick] / (1 - self.freqs[brick])
2922

30-
def down_vertex(self, edge_id, in_out):
31-
if edge_id in self.labeled_bricks:
32-
if in_out == "in":
33-
return 6 * edge_id + 4
34-
else:
35-
return 6 * edge_id + 5
23+
def log_odds(self, odds):
24+
if odds != 1:
25+
return np.log(odds) * -1
26+
else:
27+
return np.log(odds)
28+
29+
def add_edge_threshold(self, from_node, to_node, weight):
30+
if weight <= 0:
31+
weight = 0
32+
if self.threshold is not None:
33+
if weight < self.threshold:
34+
self.brick_graph.add_edge(from_node, to_node, weight=weight)
3635
else:
37-
return 6 * edge_id + 1
36+
self.brick_graph.add_edge(from_node, to_node, weight=weight)
3837

39-
def left_vertex(self, edge_id, in_out):
38+
# make an argument for in and out here, so we know how to split if it's labeled
39+
def up_vertex(self, edge_id, in_out):
40+
odds = self.find_odds(edge_id)
4041
if edge_id in self.labeled_bricks:
4142
if in_out == "in":
42-
return 6 * edge_id + 4
43+
return 4 * edge_id + 2, odds
4344
else:
44-
return 6 * edge_id + 5
45+
return 4 * edge_id + 3, odds
4546
else:
46-
return 6 * edge_id + 2
47+
return 4 * edge_id, odds
4748

48-
def right_vertex(self, edge_id, in_out):
49+
def down_vertex(self, edge_id, in_out):
50+
odds = self.find_odds(edge_id)
4951
if edge_id in self.labeled_bricks:
5052
if in_out == "in":
51-
return 6 * edge_id + 4
53+
return 4 * edge_id + 2, odds
5254
else:
53-
return 6 * edge_id + 5
55+
return 4 * edge_id + 3, odds
5456
else:
55-
return 6 * edge_id + 3
57+
return 4 * edge_id + 1, odds
5658

5759
def rule_one(self, edge, children, node_edge_dict, roots):
5860
"""
@@ -62,94 +64,65 @@ def rule_one(self, edge, children, node_edge_dict, roots):
6264
focal_node = edge.child
6365
for child in children:
6466
assert node_edge_dict[child] != node_edge_dict[focal_node]
65-
self.from_to_set.add(
66-
(
67-
self.up_vertex(node_edge_dict[child], "in"),
68-
self.up_vertex(node_edge_dict[focal_node], "out"),
69-
)
70-
)
71-
self.from_to_set.add(
72-
(
73-
self.down_vertex(node_edge_dict[focal_node], "in"),
74-
self.down_vertex(node_edge_dict[child], "out"),
75-
)
67+
# Up of child to up of parent
68+
child_label, child_odds = self.up_vertex(node_edge_dict[child], "out")
69+
parent_label, parent_odds = self.up_vertex(node_edge_dict[focal_node], "in")
70+
weight = self.log_odds(child_odds / parent_odds)
71+
self.add_edge_threshold(child_label, parent_label, weight)
72+
73+
# Down of parent to down of child
74+
parent_label, parent_odds = self.down_vertex(
75+
node_edge_dict[focal_node], "out"
7676
)
77+
child_label, child_odds = self.down_vertex(node_edge_dict[child], "in")
78+
weight = self.log_odds(child_odds / parent_odds)
79+
self.add_edge_threshold(parent_label, child_label, weight)
80+
7781
# Connect focal brick to its parent brick
7882
if edge.parent not in roots and focal_node not in roots:
7983
assert node_edge_dict[focal_node] != node_edge_dict[edge.parent]
80-
self.from_to_set.add(
81-
(
82-
self.up_vertex(node_edge_dict[focal_node], "in"),
83-
self.up_vertex(node_edge_dict[edge.parent], "out"),
84-
)
84+
child_label, child_odds = self.up_vertex(node_edge_dict[focal_node], "out")
85+
parent_label, parent_odds = self.up_vertex(
86+
node_edge_dict[edge.parent], "in"
8587
)
86-
self.from_to_set.add(
87-
(
88-
self.down_vertex(node_edge_dict[edge.parent], "in"),
89-
self.down_vertex(node_edge_dict[focal_node], "out"),
90-
)
88+
weight = self.log_odds(child_odds / parent_odds)
89+
self.add_edge_threshold(child_label, parent_label, weight)
90+
91+
parent_label, parent_odds = self.down_vertex(
92+
node_edge_dict[edge.parent], "out"
9193
)
94+
child_label, child_odds = self.down_vertex(node_edge_dict[focal_node], "in")
95+
weight = self.log_odds(child_odds / parent_odds)
96+
self.add_edge_threshold(parent_label, child_label, weight)
9297

9398
def rule_two(self, edge, siblings, node_edge_dict):
9499
# Rule 2: Connect focal brick to its siblings
95100
if len(siblings) > 1:
96-
if len(siblings) > 1:
97-
for item in itertools.combinations(siblings, 2):
98-
self.from_to_set.add(
99-
(
100-
self.up_vertex(node_edge_dict[item[0]], "in"),
101-
self.down_vertex(node_edge_dict[item[1]], "out"),
102-
)
103-
)
104-
self.from_to_set.add(
105-
(
106-
self.up_vertex(node_edge_dict[item[1]], "in"),
107-
self.down_vertex(node_edge_dict[item[0]], "out"),
108-
)
109-
)
110-
111-
def rule_three(self, edge, children, node_edge_dict, prev_edge_dict):
112-
"""
113-
Rule 3: Connect focal brick to other bricks which share a child haplotype
114-
across a recombination
115-
"""
116-
if edge.child in prev_edge_dict:
117-
# r,d of left parent to r of right parent
118-
self.from_to_set.add(
119-
(
120-
self.right_vertex(prev_edge_dict[edge.child], "in"),
121-
self.right_vertex(node_edge_dict[edge.child], "out"),
101+
for pair in itertools.combinations(siblings, 2):
102+
left_brick_up, left_odds = self.up_vertex(
103+
node_edge_dict[pair[0]], "out"
122104
)
123-
)
124-
self.from_to_set.add(
125-
(
126-
self.down_vertex(prev_edge_dict[edge.child], "in"),
127-
self.right_vertex(node_edge_dict[edge.child], "out"),
105+
right_brick_down, right_odds = self.down_vertex(
106+
node_edge_dict[pair[1]], "in"
128107
)
129-
)
130-
# l,d of right parent to l of left parent
131-
self.from_to_set.add(
132-
(
133-
self.left_vertex(node_edge_dict[edge.child], "in"),
134-
self.left_vertex(prev_edge_dict[edge.child], "out"),
108+
weight = self.log_odds(left_odds * right_odds)
109+
self.add_edge_threshold(left_brick_up, right_brick_down, weight)
110+
111+
right_brick_up, left_odds = self.up_vertex(
112+
node_edge_dict[pair[1]], "out"
135113
)
136-
)
137-
self.from_to_set.add(
138-
(
139-
self.down_vertex(node_edge_dict[edge.child], "in"),
140-
self.left_vertex(prev_edge_dict[edge.child], "out"),
114+
left_brick_down, right_odds = self.down_vertex(
115+
node_edge_dict[pair[0]], "in"
141116
)
142-
)
117+
weight = self.log_odds(left_odds * right_odds)
118+
self.add_edge_threshold(right_brick_up, left_brick_down, weight)
143119

144120
def make_connections(self, edge, tree2, node_edge_dict, index, prev_edge_dict):
145121
roots = tree2.roots
146122
children = tree2.children(edge.child)
147123
siblings = tree2.children(edge.parent)
148124
self.rule_one(edge, children, node_edge_dict, roots)
149-
if self.use_rule_two:
150-
self.rule_two(edge, siblings, node_edge_dict)
151-
if index != 0:
152-
self.rule_three(edge, siblings, node_edge_dict, prev_edge_dict)
125+
self.rule_two(edge, siblings, node_edge_dict)
153126

154127
def make_brick_graph(self):
155128
"""
@@ -174,17 +147,13 @@ def make_brick_graph(self):
174147
assert len(self.unlabeled_bricks) >= (
175148
self.bricked_ts.num_edges - self.bricked_ts.num_mutations
176149
), (len(bricks), self.bricked_ts.num_mutations, len(self.unlabeled_bricks))
177-
self.G = None
178-
self.l_in = []
179-
self.l_out = []
180150

181151
node_edge_dict = {}
182152

183153
# Rule Zero
184-
# For unlabeled nodes: connect left to up, right to up (within a brick)
154+
# For unlabeled nodes: connect down to up (within a brick)
185155
for brick in self.unlabeled_bricks:
186-
self.from_to_set.add((6 * brick + 2, 6 * brick))
187-
self.from_to_set.add((6 * brick + 3, 6 * brick))
156+
self.brick_graph.add_edge(4 * brick + 1, 4 * brick, weight=self.log_odds(1))
188157

189158
for index, (tree2, (_, edges_out, edges_in)) in tqdm(
190159
enumerate(zip(self.bricked_ts.trees(), self.bricked_ts.edge_diffs())),
@@ -207,20 +176,4 @@ def make_brick_graph(self):
207176
prev_edge_dict,
208177
)
209178

210-
# Create networkx graph
211-
df = pd.Datamuts_to_merge_dict = {
212-
"from": [cur_set[0] for cur_set in self.from_to_set],
213-
"to": [cur_set[1] for cur_set in self.from_to_set],
214-
}
215-
self.G = nx.from_pandas_edgelist(df, "from", "to", create_using=nx.DiGraph())
216-
# Total number of nodes should be less than (2 * number of labeled nodes) +
217-
# (4 * number of unlabeled nodes)
218-
# TODO: check why this isn't equal
219-
assert self.G.number_of_nodes() <= (2 * len(self.labeled_bricks)) + 4 * len(
220-
self.unlabeled_bricks
221-
), (
222-
self.G.number_of_nodes(),
223-
(2 * len(self.labeled_bricks)),
224-
4 * len(self.unlabeled_bricks),
225-
)
226-
return self.G
179+
return self.brick_graph

0 commit comments

Comments
 (0)