diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb new file mode 100644 index 000000000..da2118da3 --- /dev/null +++ b/examples/python-runtime.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from graphdatascience import GraphDataScience" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "ENVIRONMENT = \"mlruntimedev\"\n", + "DBID = \"e6ba1b5c\"\n", + "PASSWORD = \"l4Co2Qa5GseW0sMropCvJo17laf6ZCq9vuAhiJrVW2c\"" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", + "gds.set_database(\"neo4j\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "Uploading Nodes: 0%| | 0/2708 [00:00\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n
gds.remoteml.getTrainResult('model2')
0  {'test_acc_mean': 0.8589511513710022, 'test_ac...
\n" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_result" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "cora = gds.graph.get('cora')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "predictions = gds.graph.nodeProperties.stream(cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": " nodeId features \\\n0 31336 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 1061127 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n2 1106406 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n3 13195 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 37879 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n... ... ... \n2703 1128975 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2704 1128977 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2705 1128978 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2706 117328 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2707 24043 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n\n model2Predictions \n0 0 \n1 1 \n2 2 \n3 2 \n4 3 \n... ... \n2703 5 \n2704 5 \n2705 5 \n2706 6 \n2707 0 \n\n[2708 rows x 3 columns]", + "text/html": "
" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictions" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py index 4abd44247..e91c1702b 100644 --- a/graphdatascience/endpoints.py +++ b/graphdatascience/endpoints.py @@ -1,5 +1,6 @@ from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder +from .gnn.gnn_endpoints import GnnEndpoints from .graph.graph_endpoints import ( GraphAlphaEndpoints, GraphBetaEndpoints, @@ -32,7 +33,9 @@ """ -class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints): +class DirectEndpoints( + DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints +): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion): super().__init__(query_runner, namespace, server_version) diff --git a/graphdatascience/gnn/__init__.py b/graphdatascience/gnn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py new file mode 100644 index 000000000..ba1b7b2b7 --- /dev/null +++ b/graphdatascience/gnn/gnn_endpoints.py @@ -0,0 +1,18 @@ +from ..caller_base import CallerBase +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace +from .gnn_nc_runner import GNNNodeClassificationRunner + + +class GNNRunner(UncallableNamespace, IllegalAttrChecker): + @property + def nodeClassification(self) -> GNNNodeClassificationRunner: + return GNNNodeClassificationRunner( + self._query_runner, f"{self._namespace}.nodeClassification", self._server_version + ) + + +class GnnEndpoints(CallerBase): + @property + def gnn(self) -> GNNRunner: + return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py new file mode 100644 index 000000000..6e65e2337 --- /dev/null +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -0,0 +1,124 @@ +import json +from typing import Any, List +import time + +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace + + +class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def make_graph_sage_config(self, graph_sage_config): + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, + "hidden_channels": 256, "learning_rate": 0.003} + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG + if graph_sage_config: + bad_keys = [] + for key in graph_sage_config: + if key not in GRAPH_SAGE_DEFAULT_CONFIG: + bad_keys.append(key) + if len(bad_keys) > 0: + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") + + final_sage_config.update(graph_sage_config) + return final_sage_config + + def get_logs(self, job_id: str, offset=0) -> "Series[Any]": # noqa: F821 + return self._query_runner.run_query( + "RETURN gds.remoteml.getLogs($job_id, $offset)", + params={ + "job_id": 
job_id, + "offset": offset + }).squeeze() + + def get_train_result(self, model_name: str) -> "Series[Any]": # noqa: F821 + return self._query_runner.run_query( + "RETURN gds.remoteml.getTrainResult($model_name)", + params={ + "model_name": model_name + }).squeeze() + + def train( + self, + graph_name: str, + model_name: str, + feature_properties: List[str], + target_property: str, + relationship_types: List[str], + target_node_label: str = None, + node_labels: List[str] = None, + graph_sage_config = None, + logging_interval: int = 5 + ) -> "Series[Any]": # noqa: F821 + mlConfigMap = { + "featureProperties": feature_properties, + "targetProperty": target_property, + "job_type": "train", + "nodeProperties": feature_properties + [target_property], + "relationshipTypes": relationship_types, + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) + } + + if target_node_label: + mlConfigMap["targetNodeLabel"] = target_node_label + if node_labels: + mlConfigMap["nodeLabels"] = node_labels + + mlTrainingConfig = json.dumps(mlConfigMap) + + # token and uri will be injected by arrow_query_runner + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", + params={ + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, + }, + ).jobId[0] + + received_logs = 0 + training_done = False + while not training_done: + for log in self.get_logs(job_id, offset=received_logs): + print(log) + received_logs += 1 + try: + self.get_train_result(model_name) + training_done = True + except Exception: + time.sleep(logging_interval) + + return job_id + + + + def predict( + self, + graph_name: str, + model_name: str, + mutateProperty: str, + predictedProbabilityProperty: str = None, + logging_interval = 5 + ) -> "Series[Any]": # noqa: F821 + mlConfigMap = { + "job_type": "predict", + "mutateProperty": mutateProperty + } + if predictedProbabilityProperty: + mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty + + mlTrainingConfig = json.dumps(mlConfigMap) + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", + params={ + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, + }, + ).jobId[0] + + received_logs = 0 + prediction_done = False + while not prediction_done: + for log in self.get_logs(job_id, offset=received_logs): + print(log) + received_logs += 1 + if log == "Prediction job completed": + prediction_done = True + if not prediction_done: + time.sleep(logging_interval) diff --git a/graphdatascience/ignored_server_endpoints.py b/graphdatascience/ignored_server_endpoints.py index 89ad9f0b2..d103a90c4 100644 --- a/graphdatascience/ignored_server_endpoints.py +++ b/graphdatascience/ignored_server_endpoints.py @@ -47,6 +47,7 @@ "gds.alpha.pipeline.nodeRegression.predict.stream", "gds.alpha.pipeline.nodeRegression.selectFeatures", "gds.alpha.pipeline.nodeRegression.train", + "gds.gnn.nc", "gds.similarity.cosine", "gds.similarity.euclidean", "gds.similarity.euclideanDistance", diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index cf648879a..eab64398c 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -29,6 +29,9 @@ def __init__( ): self._fallback_query_runner = fallback_query_runner self._server_version = server_version + # FIXME handle version were tls cert is given + self._auth = auth 
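+        # auth and uri are kept so run_query() can inject a fresh Flight token and the Arrow endpoint into gds.upload.graph calls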
+ self._uri = uri host, port_string = uri.split(":") @@ -39,8 +42,9 @@ def __init__( ) client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} + self._auth_factory = AuthFactory(auth) if auth: - client_options["middleware"] = [AuthFactory(auth)] + client_options["middleware"] = [self._auth_factory] if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs @@ -129,6 +133,10 @@ def run_query( endpoint = "gds.beta.graph.relationships.stream" return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types}) + elif "gds.upload.graph" in query: + # inject parameters + params["config"]["token"] = self._get_or_request_token() + params["config"]["arrowEndpoint"] = self._uri return self._fallback_query_runner.run_query(query, params, database, custom_error) @@ -184,6 +192,10 @@ def create_graph_constructor( database, graph_name, self._flight_client, concurrency, undirected_relationship_types ) + def _get_or_request_token(self) -> str: + self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) + return self._auth_factory.token() + class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: @@ -217,9 +229,14 @@ def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None: self._factory = factory def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header: str = headers.get("Authorization", None) + auth_header: str = headers.get("authorization", None) if not auth_header: return + # authenticate_basic_token() returns a list. + # TODO We should take the first Bearer element here + if isinstance(auth_header, list): + auth_header = auth_header[0] + [auth_type, token] = auth_header.split(" ", 1) if auth_type == "Bearer": self._factory.set_token(token)
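A minimal usage sketch of the new gds.gnn.nodeClassification endpoints, condensed from the example notebook added in this patch. The connection URI and password below are placeholders; the graph, model, and property names follow the notebook.

from graphdatascience import GraphDataScience

# Placeholder connection details for an Arrow-enabled GDS instance.
gds = GraphDataScience("neo4j+s://<dbid>.databases.neo4j.io", auth=("neo4j", "<password>"))
gds.set_database("neo4j")

# Train a node classification GNN on the projected graph "cora".
# train() uploads the graph via gds.upload.graph, then prints remote logs until a train result is available.
train_result = gds.gnn.nodeClassification.train(
    "cora", "myModel", ["features"], "subject", ["CITES"],
    target_node_label="Paper", node_labels=["Paper"]
)

# Mutate the projection with predictions, then stream them back alongside the input features.
gds.gnn.nodeClassification.predict("cora", "myModel", "myPredictions")
cora = gds.graph.get("cora")
predictions = gds.graph.nodeProperties.stream(
    cora, node_properties=["features", "myPredictions"], separate_property_columns=True
)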