Skip to content

Commit 2e5e514

Browse files
authored
Merge pull request #165 from zhi-yi-huang/issue#164
Fixed issue#164
2 parents 0991c80 + 437b653 commit 2e5e514

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

causallearn/search/ScoreBased/GES.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def ges(X: ndarray, score_func: str = 'local_score_BIC', maxP: Optional[float] =
139139
np.where(G.graph[j, :] == Endpoint.TAIL.value)[0]) # neighbors of Xj
140140

141141
Ti = np.union1d(np.where(G.graph[:, i] != Endpoint.NULL.value)[0],
142-
np.where(G.graph[i, 0] != Endpoint.NULL.value)[0]) # adjacent to Xi
142+
np.where(G.graph[i, :] != Endpoint.NULL.value)[0]) # adjacent to Xi
143143

144144
NTi = np.setdiff1d(np.arange(N), Ti)
145145
T0 = np.intersect1d(Tj, NTi) # find the neighbours of Xj that are not adjacent to Xi

tests/TestGES.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
######################################### Test Notes ###########################################
13-
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
13+
# All the benchmark results of loaded files (e.g. "tests/TestData/benchmark_returned_results/")#
1414
# are obtained from the code of causal-learn as of commit #
1515
# https://github.com/cmu-phil/causal-learn/commit/b51d788 (07-08-2022). #
1616
# #
@@ -23,13 +23,13 @@
2323

2424

2525
BENCHMARK_TXTFILE_TO_MD5 = {
26-
"./TestData/data_linear_10.txt": "95a17e15038d4cade0845140b67c05a6",
27-
"./TestData/data_discrete_10.txt": "ccb51c6c1946d8524a8b29a49aef2cc4",
28-
"./TestData/graph.10.txt": "4970d4ecb8be999a82a665e5f5e0825b",
29-
"./TestData/test_ges_simulated_linear_gaussian_data.txt": "0d2490eeb9ee8ef3b18bf21d7e936e1e",
30-
"./TestData/test_ges_simulated_linear_gaussian_CPDAG.txt": "aa0146777186b07e56421ce46ed52914",
31-
"./TestData/benchmark_returned_results/linear_10_ges_local_score_BIC_none_none.txt": "3accb3673d2ccb4c110f3703d60fe702",
32-
"./TestData/benchmark_returned_results/discrete_10_ges_local_score_BDeu_none_none.txt": "eebd11747c1b927b2fdd048a55c8c3a5",
26+
"tests/TestData/data_linear_10.txt": "95a17e15038d4cade0845140b67c05a6",
27+
"tests/TestData/data_discrete_10.txt": "ccb51c6c1946d8524a8b29a49aef2cc4",
28+
"tests/TestData/graph.10.txt": "4970d4ecb8be999a82a665e5f5e0825b",
29+
"tests/TestData/test_ges_simulated_linear_gaussian_data.txt": "0d2490eeb9ee8ef3b18bf21d7e936e1e",
30+
"tests/TestData/test_ges_simulated_linear_gaussian_CPDAG.txt": "aa0146777186b07e56421ce46ed52914",
31+
"tests/TestData/benchmark_returned_results/linear_10_ges_local_score_BIC_none_none.txt": "3accb3673d2ccb4c110f3703d60fe702",
32+
"tests/TestData/benchmark_returned_results/discrete_10_ges_local_score_BDeu_none_none.txt": "eebd11747c1b927b2fdd048a55c8c3a5",
3333
}
3434

3535
INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit b51d788."
@@ -47,8 +47,8 @@ class TestGES(unittest.TestCase):
4747
# Load data from file "data_linear_10.txt". Run GES with local_score_BIC.
4848
def test_ges_load_linear_10_with_local_score_BIC(self):
4949
print('Now start test_ges_load_linear_10_with_local_score_BIC ...')
50-
data_path = "./TestData/data_linear_10.txt"
51-
truth_graph_path = "./TestData/graph.10.txt"
50+
data_path = "tests/TestData/data_linear_10.txt"
51+
truth_graph_path = "tests/TestData/graph.10.txt"
5252
data = np.loadtxt(data_path, skiprows=1)
5353
truth_dag = txt2generalgraph(truth_graph_path) # truth_dag is a GeneralGraph instance
5454
truth_cpdag = dag2cpdag(truth_dag)
@@ -58,7 +58,7 @@ def test_ges_load_linear_10_with_local_score_BIC(self):
5858
res_map = ges(data, score_func='local_score_BIC', maxP=None, parameters=None) # Run GES and obtain the estimated graph (res_map is Dict object,which contains the updated steps, the result causal graph and the result score.)
5959

6060
benchmark_returned_graph = np.loadtxt(
61-
"./TestData/benchmark_returned_results/linear_10_ges_local_score_BIC_none_none.txt")
61+
"tests/TestData/benchmark_returned_results/linear_10_ges_local_score_BIC_none_none.txt")
6262
assert np.all(res_map['G'].graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG
6363
shd = SHD(truth_cpdag, res_map['G'])
6464
print(f" ges(data, score_func='local_score_BIC', maxP=None, parameters=None)\tSHD: {shd.get_shd()} of {num_edges_in_truth}")
@@ -73,9 +73,9 @@ def test_ges_simulate_linear_gaussian_with_local_score_BIC(self):
7373
truth_DAG_directed_edges = {(0, 1), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)}
7474
truth_CPDAG_directed_edges = {(0, 3), (1, 3), (2, 3), (2, 4), (3, 4)}
7575
truth_CPDAG_undirected_edges = {(0, 1), (1, 2), (2, 1), (1, 0)}
76-
truth_CPDAG = np.loadtxt("./TestData/test_ges_simulated_linear_gaussian_CPDAG.txt")
76+
truth_CPDAG = np.loadtxt("tests/TestData/test_ges_simulated_linear_gaussian_CPDAG.txt")
7777

78-
###### Simulation configuration: code to generate "./TestData/test_ges_simulated_linear_gaussian_data.txt" ######
78+
###### Simulation configuration: code to generate "tests/TestData/test_ges_simulated_linear_gaussian_data.txt" ######
7979
# np.random.seed(42)
8080
# linear_weight_minabs, linear_weight_maxabs, linear_weight_netative_prob = 0.5, 0.9, 0.5
8181
# sample_size = 10000
@@ -89,10 +89,10 @@ def test_ges_simulate_linear_gaussian_with_local_score_BIC(self):
8989
# mixing_matrix = np.linalg.inv(np.eye(num_of_nodes) - adjacency_matrix)
9090
# exogenous_noise = np.random.normal(0, 1, (num_of_nodes, sample_size))
9191
# data = (mixing_matrix @ exogenous_noise).T
92-
# np.savetxt("./TestData/test_ges_simulated_linear_gaussian_data.txt", data)
93-
###### Simulation configuration: code to generate "./TestData/test_ges_simulated_linear_gaussian_data.txt" ######
92+
# np.savetxt("tests/TestData/test_ges_simulated_linear_gaussian_data.txt", data)
93+
###### Simulation configuration: code to generate "tests/TestData/test_ges_simulated_linear_gaussian_data.txt" ######
9494

95-
data = np.loadtxt("./TestData/test_ges_simulated_linear_gaussian_data.txt")
95+
data = np.loadtxt("tests/TestData/test_ges_simulated_linear_gaussian_data.txt")
9696

9797
# Run GES with default parameters: score_func='local_score_BIC', maxP=None, parameters=None
9898
res_map = ges(data, score_func='local_score_BIC', maxP=None, parameters=None)
@@ -105,8 +105,8 @@ def test_ges_simulate_linear_gaussian_with_local_score_BIC(self):
105105
# Load data from file "data_discrete_10.txt". Run GES with local_score_BDeu.
106106
def test_ges_load_discrete_10_with_local_score_BDeu(self):
107107
print('Now start test_ges_load_discrete_10_with_local_score_BDeu ...')
108-
data_path = "./TestData/data_discrete_10.txt"
109-
truth_graph_path = "./TestData/graph.10.txt"
108+
data_path = "tests/TestData/data_discrete_10.txt"
109+
truth_graph_path = "tests/TestData/graph.10.txt"
110110
data = np.loadtxt(data_path, skiprows=1)
111111
truth_dag = txt2generalgraph(truth_graph_path) # truth_dag is a GeneralGraph instance
112112
truth_cpdag = dag2cpdag(truth_dag)
@@ -115,7 +115,7 @@ def test_ges_load_discrete_10_with_local_score_BDeu(self):
115115
# Run GES with local_score_BDeu.
116116
res_map = ges(data, score_func='local_score_BDeu', maxP=None, parameters=None)
117117
benchmark_returned_graph = np.loadtxt(
118-
"./TestData/benchmark_returned_results/discrete_10_ges_local_score_BDeu_none_none.txt")
118+
"tests/TestData/benchmark_returned_results/discrete_10_ges_local_score_BDeu_none_none.txt")
119119
assert np.all(res_map['G'].graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG
120120
shd = SHD(truth_cpdag, res_map['G'])
121121
print(f" ges(data, score_func='local_score_BDeu', maxP=None, parameters=None)\tSHD: {shd.get_shd()} of {num_edges_in_truth}")

0 commit comments

Comments
 (0)