Skip to content

Commit 7410a6a

Browse files
brs96FlorentinDorazve
committed
Cleanup nc_runner
Co-authored-by: Florentin Dörre <[email protected]> Co-authored-by: Olga Razvenskaia <[email protected]>
1 parent 6cd2c54 commit 7410a6a

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

graphdatascience/gnn/gnn_nc_runner.py

+42-22
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,66 @@
1+
import json
12
from typing import Any, List
23

34
from ..error.illegal_attr_checker import IllegalAttrChecker
45
from ..error.uncallable_namespace import UncallableNamespace
5-
import json
66

77

88
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9-
def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str,
10-
target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
11-
configMap = {
9+
def train(
10+
self,
11+
graph_name: str,
12+
model_name: str,
13+
feature_properties: List[str],
14+
target_property: str,
15+
target_node_label: str = None,
16+
node_labels: List[str] = None,
17+
) -> "Series[Any]": # noqa: F821
18+
mlConfigMap = {
1219
"featureProperties": feature_properties,
1320
"targetProperty": target_property,
1421
"job_type": "train",
15-
"nodeProperties": feature_properties + [target_property]
22+
"nodeProperties": feature_properties + [target_property],
1623
}
1724

1825
if target_node_label:
19-
configMap["targetNodeLabel"] = target_node_label
26+
mlConfigMap["targetNodeLabel"] = target_node_label
2027
if node_labels:
21-
configMap["nodeLabels"] = node_labels
28+
mlConfigMap["nodeLabels"] = node_labels
2229

23-
mlTrainingConfig = json.dumps(configMap)
30+
mlTrainingConfig = json.dumps(mlConfigMap)
2431

2532
# token and uri will be injected by arrow_query_runner
2633
self._query_runner.run_query(
27-
f"CALL gds.upload.graph($graph_name, $config)",
28-
params={"graph_name": graph_name, "config": {
29-
"mlTrainingConfig": mlTrainingConfig,
30-
"modelName": model_name
31-
}}
32-
)
33-
34+
"CALL gds.upload.graph($graph_name, $config)",
35+
params={
36+
"graph_name": graph_name,
37+
"config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name},
38+
},
39+
)
3440

35-
def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
36-
configMap = {
41+
def predict(
42+
self,
43+
graph_name: str,
44+
model_name: str,
45+
feature_properties: List[str],
46+
target_node_label: str = None,
47+
node_labels: List[str] = None,
48+
) -> "Series[Any]": # noqa: F821
49+
mlConfigMap = {
3750
"featureProperties": feature_properties,
3851
"job_type": "predict",
52+
"nodeProperties": feature_properties,
3953
}
4054
if target_node_label:
41-
configMap["targetNodeLabel"] = target_node_label
42-
mlTrainingConfig = json.dumps(configMap)
43-
# TODO query available node labels
44-
node_labels = ["Paper"] if not node_labels else node_labels
55+
mlConfigMap["targetNodeLabel"] = target_node_label
56+
if node_labels:
57+
mlConfigMap["nodeLabels"] = node_labels
58+
59+
mlTrainingConfig = json.dumps(mlConfigMap)
4560
self._query_runner.run_query(
46-
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})")
61+
"CALL gds.upload.graph($graph_name, $config)",
62+
params={
63+
"graph_name": graph_name,
64+
"config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name},
65+
},
66+
) # type: ignore

graphdatascience/query_runner/arrow_query_runner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def create_graph_constructor(
191191
return ArrowGraphConstructor(
192192
database, graph_name, self._flight_client, concurrency, undirected_relationship_types
193193
)
194-
194+
195195
def _get_or_request_token(self) -> str:
196196
self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
197197
return self._auth_factory.token()
@@ -232,7 +232,7 @@ def received_headers(self, headers: Dict[str, Any]) -> None:
232232
auth_header: str = headers.get("authorization", None)
233233
if not auth_header:
234234
return
235-
# authenticate_basic_token() returns a list.
235+
# authenticate_basic_token() returns a list.
236236
# TODO We should take the first Bearer element here
237237
if isinstance(auth_header, list):
238238
auth_header = auth_header[0]

0 commit comments

Comments
 (0)