Skip to content

Commit 28b22ed

Browse files
committed
WIP
1 parent daf86bb commit 28b22ed

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

graphdatascience/gnn/__init__.py

Whitespace-only changes.

graphdatascience/gnn/gnn_endpoints.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from .gnn_nc_runner import GNNNodeClassificationRunner
2+
from ..caller_base import CallerBase
3+
from ..error.illegal_attr_checker import IllegalAttrChecker
4+
from ..error.uncallable_namespace import UncallableNamespace
5+
6+
class GNNRunner(UncallableNamespace, IllegalAttrChecker):
7+
@property
8+
def nodeClassification(self) -> GNNNodeClassificationRunner:
9+
return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version)
10+
11+
class GnnEndpoints(CallerBase):
12+
@property
13+
def gnn(self) -> GNNRunner:
14+
return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
15+
16+
17+

graphdatascience/gnn/gnn_nc_runner.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Any, List
2+
3+
from ..error.illegal_attr_checker import IllegalAttrChecker
4+
from ..error.uncallable_namespace import UncallableNamespace
5+
import json
6+
7+
8+
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9+
def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str,
10+
target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
11+
configMap = {
12+
"feature_properties": feature_properties,
13+
"target_property": target_property,
14+
}
15+
if target_node_label:
16+
configMap["targetNodeLabel"] = target_node_label
17+
mlTrainingConfig = json.dumps(configMap)
18+
# TODO query avaiable node labels
19+
node_labels = ["Paper"] if not node_labels else node_labels
20+
self._query_runner.run_query(f"CALL gds.upload.graph({graph_name}, {{mlTrainingConfig: {mlTrainingConfig}, modelName: {model_name}, nodeLabels: {node_labels}}})")

graphdatascience/ignored_server_endpoints.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"gds.alpha.pipeline.nodeRegression.predict.stream",
4848
"gds.alpha.pipeline.nodeRegression.selectFeatures",
4949
"gds.alpha.pipeline.nodeRegression.train",
50+
"gds.gnn.nc",
5051
"gds.similarity.cosine",
5152
"gds.similarity.euclidean",
5253
"gds.similarity.euclideanDistance",

0 commit comments

Comments
 (0)