5
5
6
6
import networkx as nx
7
7
import numpy as np
8
- import pandas as pd
9
8
from tqdm import tqdm
10
9
11
10
from . import utility
12
11
13
12
14
13
class BrickGraph :
15
def __init__(self, bricked_ts, threshold):
    """
    Set up an empty brick graph for a bricked tree sequence.

    :param bricked_ts: The bricked tree sequence the graph is built from.
    :param threshold: Cut-off consulted by ``add_edge_threshold``; ``None``
        means every edge is kept.
    """
    self.bricked_ts = bricked_ts
    self.threshold = threshold
    # Directed graph that the rule methods populate with weighted edges.
    self.brick_graph = nx.DiGraph()
    # Per-brick frequencies, looked up by ``find_odds``.
    self.freqs = utility.get_brick_frequencies(bricked_ts)
19
19
20
- # make an argument for in and out here, so we know how to split if it's labeled
21
- def up_vertex (self , edge_id , in_out ):
22
- if edge_id in self .labeled_bricks :
23
- if in_out == "in" :
24
- return 6 * edge_id + 4
25
- else :
26
- return 6 * edge_id + 5
27
- else :
28
- return 6 * edge_id
20
def find_odds(self, brick):
    """
    Return the odds ``f / (1 - f)`` where ``f`` is ``self.freqs[brick]``.

    NOTE(review): a frequency of exactly 1 makes the denominator zero;
    callers presumably never hit that case -- confirm.
    """
    frequency = self.freqs[brick]
    return frequency / (1 - frequency)
29
22
30
- def down_vertex (self , edge_id , in_out ):
31
- if edge_id in self .labeled_bricks :
32
- if in_out == "in" :
33
- return 6 * edge_id + 4
34
- else :
35
- return 6 * edge_id + 5
23
def log_odds(self, odds):
    """
    Return the negated natural logarithm of ``odds``.

    When ``odds`` is exactly 1 the logarithm is returned without negation,
    so the result is ``0.0`` rather than ``-0.0``.
    """
    log_value = np.log(odds)
    if odds == 1:
        return log_value
    return log_value * -1
28
+
29
def add_edge_threshold(self, from_node, to_node, weight):
    """
    Add a weighted edge to ``self.brick_graph`` subject to the threshold.

    Non-positive weights are replaced by 0. With ``self.threshold`` set to
    ``None`` the edge is always added; otherwise it is added only when the
    (clamped) weight is strictly below the threshold.
    """
    clamped = 0 if weight <= 0 else weight
    # ``not (a < b)`` rather than ``a >= b`` keeps NaN weights excluded.
    if self.threshold is not None and not (clamped < self.threshold):
        return
    self.brick_graph.add_edge(from_node, to_node, weight=clamped)
38
37
39
- def left_vertex (self , edge_id , in_out ):
38
+ # make an argument for in and out here, so we know how to split if it's labeled
39
def up_vertex(self, edge_id, in_out):
    """
    Return ``(vertex_id, odds)`` for the "up" vertex of brick ``edge_id``.

    Bricks in ``self.labeled_bricks`` map to ``4 * edge_id + 2`` when
    ``in_out`` is ``"in"`` and to ``4 * edge_id + 3`` for any other value.
    All other bricks map to ``4 * edge_id`` regardless of ``in_out``.
    ``odds`` is ``self.find_odds(edge_id)``.
    """
    odds = self.find_odds(edge_id)
    if edge_id not in self.labeled_bricks:
        return 4 * edge_id, odds
    offset = 2 if in_out == "in" else 3
    return 4 * edge_id + offset, odds
47
48
48
- def right_vertex (self , edge_id , in_out ):
49
def down_vertex(self, edge_id, in_out):
    """
    Return ``(vertex_id, odds)`` for the "down" vertex of brick ``edge_id``.

    Bricks in ``self.labeled_bricks`` map to ``4 * edge_id + 2`` when
    ``in_out`` is ``"in"`` and to ``4 * edge_id + 3`` for any other value.
    All other bricks map to ``4 * edge_id + 1`` regardless of ``in_out``.
    ``odds`` is ``self.find_odds(edge_id)``.
    """
    odds = self.find_odds(edge_id)
    if edge_id not in self.labeled_bricks:
        return 4 * edge_id + 1, odds
    offset = 2 if in_out == "in" else 3
    return 4 * edge_id + offset, odds
56
58
57
59
def rule_one (self , edge , children , node_edge_dict , roots ):
58
60
"""
@@ -62,94 +64,65 @@ def rule_one(self, edge, children, node_edge_dict, roots):
62
64
focal_node = edge .child
63
65
for child in children :
64
66
assert node_edge_dict [child ] != node_edge_dict [focal_node ]
65
- self .from_to_set .add (
66
- (
67
- self .up_vertex (node_edge_dict [child ], "in" ),
68
- self .up_vertex (node_edge_dict [focal_node ], "out" ),
69
- )
70
- )
71
- self .from_to_set .add (
72
- (
73
- self .down_vertex (node_edge_dict [focal_node ], "in" ),
74
- self .down_vertex (node_edge_dict [child ], "out" ),
75
- )
67
+ # Up of child to up of parent
68
+ child_label , child_odds = self .up_vertex (node_edge_dict [child ], "out" )
69
+ parent_label , parent_odds = self .up_vertex (node_edge_dict [focal_node ], "in" )
70
+ weight = self .log_odds (child_odds / parent_odds )
71
+ self .add_edge_threshold (child_label , parent_label , weight )
72
+
73
+ # Down of parent to down of child
74
+ parent_label , parent_odds = self .down_vertex (
75
+ node_edge_dict [focal_node ], "out"
76
76
)
77
+ child_label , child_odds = self .down_vertex (node_edge_dict [child ], "in" )
78
+ weight = self .log_odds (child_odds / parent_odds )
79
+ self .add_edge_threshold (parent_label , child_label , weight )
80
+
77
81
# Connect focal brick to its parent brick
78
82
if edge .parent not in roots and focal_node not in roots :
79
83
assert node_edge_dict [focal_node ] != node_edge_dict [edge .parent ]
80
- self .from_to_set .add (
81
- (
82
- self .up_vertex (node_edge_dict [focal_node ], "in" ),
83
- self .up_vertex (node_edge_dict [edge .parent ], "out" ),
84
- )
84
+ child_label , child_odds = self .up_vertex (node_edge_dict [focal_node ], "out" )
85
+ parent_label , parent_odds = self .up_vertex (
86
+ node_edge_dict [edge .parent ], "in"
85
87
)
86
- self .from_to_set . add (
87
- (
88
- self . down_vertex ( node_edge_dict [ edge . parent ], "in" ),
89
- self .down_vertex (node_edge_dict [ focal_node ], "out" ),
90
- )
88
+ weight = self .log_odds ( child_odds / parent_odds )
89
+ self . add_edge_threshold ( child_label , parent_label , weight )
90
+
91
+ parent_label , parent_odds = self .down_vertex (
92
+ node_edge_dict [ edge . parent ], "out"
91
93
)
94
+ child_label , child_odds = self .down_vertex (node_edge_dict [focal_node ], "in" )
95
+ weight = self .log_odds (child_odds / parent_odds )
96
+ self .add_edge_threshold (parent_label , child_label , weight )
92
97
93
98
def rule_two(self, edge, siblings, node_edge_dict):
    """
    Rule 2: Connect focal brick to its siblings.

    For every unordered pair of nodes in ``siblings``, two edges are
    offered to ``add_edge_threshold``: from the up vertex of the first
    node's brick to the down vertex of the second node's brick, and the
    mirror image. Each edge is weighted by ``log_odds`` of the product of
    the two bricks' odds. ``edge`` is accepted but not used here.
    """
    if len(siblings) <= 1:
        return
    for first, second in itertools.combinations(siblings, 2):
        # Same stanza in both orientations: first -> second, then back.
        for source, target in ((first, second), (second, first)):
            source_up, source_odds = self.up_vertex(node_edge_dict[source], "out")
            target_down, target_odds = self.down_vertex(
                node_edge_dict[target], "in"
            )
            weight = self.log_odds(source_odds * target_odds)
            self.add_edge_threshold(source_up, target_down, weight)
143
119
144
120
def make_connections(self, edge, tree2, node_edge_dict, index, prev_edge_dict):
    """
    Apply rule one and then rule two for ``edge`` within ``tree2``.

    Rule one receives the children of ``edge.child`` and the tree's roots;
    rule two receives the children of ``edge.parent``. ``index`` and
    ``prev_edge_dict`` are accepted but not used here.
    """
    tree_roots = tree2.roots
    child_nodes = tree2.children(edge.child)
    sibling_nodes = tree2.children(edge.parent)

    self.rule_one(edge, child_nodes, node_edge_dict, tree_roots)
    self.rule_two(edge, sibling_nodes, node_edge_dict)
153
126
154
127
def make_brick_graph (self ):
155
128
"""
@@ -174,17 +147,13 @@ def make_brick_graph(self):
174
147
assert len (self .unlabeled_bricks ) >= (
175
148
self .bricked_ts .num_edges - self .bricked_ts .num_mutations
176
149
), (len (bricks ), self .bricked_ts .num_mutations , len (self .unlabeled_bricks ))
177
- self .G = None
178
- self .l_in = []
179
- self .l_out = []
180
150
181
151
node_edge_dict = {}
182
152
183
153
# Rule Zero
184
- # For unlabeled nodes: connect left to up, right to up (within a brick)
154
+ # For unlabeled nodes: connect down to up (within a brick)
185
155
for brick in self .unlabeled_bricks :
186
- self .from_to_set .add ((6 * brick + 2 , 6 * brick ))
187
- self .from_to_set .add ((6 * brick + 3 , 6 * brick ))
156
+ self .brick_graph .add_edge (4 * brick + 1 , 4 * brick , weight = self .log_odds (1 ))
188
157
189
158
for index , (tree2 , (_ , edges_out , edges_in )) in tqdm (
190
159
enumerate (zip (self .bricked_ts .trees (), self .bricked_ts .edge_diffs ())),
@@ -207,20 +176,4 @@ def make_brick_graph(self):
207
176
prev_edge_dict ,
208
177
)
209
178
210
- # Create networkx graph
211
- df = pd .Datamuts_to_merge_dict = {
212
- "from" : [cur_set [0 ] for cur_set in self .from_to_set ],
213
- "to" : [cur_set [1 ] for cur_set in self .from_to_set ],
214
- }
215
- self .G = nx .from_pandas_edgelist (df , "from" , "to" , create_using = nx .DiGraph ())
216
- # Total number of nodes should be less than (2 * number of labeled nodes) +
217
- # (4 * number of unlabeled nodes)
218
- # TODO: check why this isn't equal
219
- assert self .G .number_of_nodes () <= (2 * len (self .labeled_bricks )) + 4 * len (
220
- self .unlabeled_bricks
221
- ), (
222
- self .G .number_of_nodes (),
223
- (2 * len (self .labeled_bricks )),
224
- 4 * len (self .unlabeled_bricks ),
225
- )
226
- return self .G
179
+ return self .brick_graph
0 commit comments