From cf63d86372ad33df7fcb2449df4ed4ffb1ca948a Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Wed, 14 Jun 2023 18:17:15 +0200 Subject: [PATCH 1/9] Add pythonic projection to Cypher mapper DSL --- graphdatascience/graph/graph_cypher_runner.py | 242 ++++++++++++++++++ graphdatascience/graph/graph_proc_runner.py | 6 + .../tests/unit/test_graph_cypher.py | 224 ++++++++++++++++ 3 files changed, 472 insertions(+) create mode 100644 graphdatascience/graph/graph_cypher_runner.py create mode 100644 graphdatascience/tests/unit/test_graph_cypher.py diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py new file mode 100644 index 000000000..2bacb4068 --- /dev/null +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -0,0 +1,242 @@ +from collections import namedtuple +from typing import Any, NamedTuple, Optional, Tuple + +from pandas import Series + +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..query_runner.query_runner import QueryRunner +from ..server_version.server_version import ServerVersion +from .graph_object import Graph + + +class NodeProperty(NamedTuple): + name: str + property_key: str + default_value: Optional[Any] = None + + +class NodeProjection(NamedTuple): + name: str + source_label: str + properties: Optional[list[NodeProperty]] = None + + +class RelationshipProperty(NamedTuple): + name: str + property_key: str + default_value: Optional[Any] = None + + +class RelationshipProjection(NamedTuple): + name: str + source_type: str + properties: Optional[list[RelationshipProperty]] = None + + +class MatchPart(NamedTuple): + match: str = "" + source_where: str = "" + optional_match: str = "" + optional_where: str = "" + + def __str__(self) -> str: + return "\n".join( + part + for part in [ + self.match, + self.source_where, + self.optional_match, + self.optional_where, + ] + if part + ) + + +class MatchPattern(NamedTuple): + label_filter: str = "" + left_arrow: str = "" + type_filter: str = "" 
+ right_arrow: str = "" + + def __str__(self) -> str: + return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})" + + +class GraphCypherRunner(IllegalAttrChecker): + def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None: + if server_version < ServerVersion(2, 4, 0): + raise ValueError("The new Cypher projection is only supported since GDS 2.4.0.") + super().__init__(query_runner, namespace, server_version) + + def project( + self, + graph_name: str, + *, + nodes: Any = None, + relationships: Any = None, + where: Optional[str] = None, + allow_disconnected_nodes: bool = False, + inverse: bool = False, + combine_labels_with: str = "OR", + **config: Any, + ) -> Tuple[Graph, "Series[Any]"]: + """ + Project a graph using Cypher projection. + + Parameters + ---------- + graph_name : str + The name of the graph to project. + nodes : Any + The nodes to project. If not specified, all nodes are projected. + relationships : Any + The relationships to project. If not specified, all relationships + are projected. + where : Optional[str] + A Cypher WHERE clause to filter the nodes and relationships to + project. + allow_disconnected_nodes : bool + Whether to allow disconnected nodes in the projected graph. + inverse : bool + Whether to project inverse relationships. The projected graph will + be configured as NATURAL. + combine_labels_with : str + Whether to combine node labels with AND or OR. The default is OR. + Allowed values are 'AND' and 'OR'. + **config : Any + Additional configuration for the projection. 
+ """ + + query_params = {"graph_name": graph_name} + + data_config = {} + + nodes = self._node_projections_spec(nodes) + rels = self._rel_projections_spec(relationships) + + match_part = MatchPart() + match_pattern = MatchPattern( + left_arrow="<-" if inverse else "-", + right_arrow="-" if inverse else "->", + ) + + if nodes: + if len(nodes) == 1 or combine_labels_with == "AND": + match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}") + + projected_labels = [spec.name for spec in nodes] + data_config["sourceNodeLabels"] = projected_labels + data_config["targetNodeLabels"] = projected_labels + + elif combine_labels_with == "OR": + source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in nodes) + target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in nodes) + if allow_disconnected_nodes: + match_part = match_part._replace( + source_where=f"WHERE {source_labels_filter}", optional_where=f"WHERE {target_labels_filter}" + ) + else: + match_part = match_part._replace( + source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})" + ) + + data_config["sourceNodeLabels"] = "labels(source)" + data_config["targetNodeLabels"] = "labels(target)" + else: + raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}") + + if rels: + if len(rels) == 1: + rel_var = "" + data_config["relationshipType"] = rels[0].source_type + else: + rel_var = "rel" + data_config["relationshipTypes"] = "type(rel)" + match_pattern = match_pattern._replace( + type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]" + ) + + source = f"(source{match_pattern.label_filter})" + if allow_disconnected_nodes: + match_part = match_part._replace( + match=f"MATCH {source}", optional_match=f"OPTIONAL MATCH (source){match_pattern}" + ) + else: + match_part = match_part._replace(match=f"MATCH {source}{match_pattern}") + + match_part = str(match_part) + + args = 
["$graph_name", "source", "target"] + + if data_config: + query_params["data_config"] = data_config + args += ["$data_config"] + + if config: + query_params["config"] = config + args += ["$config"] + + return_part = f"RETURN {self._namespace}({', '.join(args)})" + + query = "\n".join(part for part in [match_part, return_part] if part) + + print(query) + + result = self._query_runner.run_query_with_logging( + query, + query_params, + ).squeeze() + + return Graph(graph_name, self._query_runner, self._server_version), result + + def _node_projections_spec(self, spec: Any) -> list[NodeProjection]: + if spec is None or spec is False: + return [] + + if isinstance(spec, str): + spec = [spec] + + if isinstance(spec, list): + return [self._node_projection_spec(node) for node in spec] + + if isinstance(spec, dict): + return [self._node_projection_spec(node, name) for name, node in spec.items()] + + raise TypeError(f"Invalid node projection specification: {spec}") + + def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection: + if isinstance(spec, str): + return NodeProjection(name=name or spec, source_label=spec) + + raise TypeError(f"Invalid node projection specification: {spec}") + + def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]: + raise TypeError(f"Invalid node projection specification: {properties}") + + def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]: + if spec is None or spec is False: + return [] + + if isinstance(spec, str): + spec = [spec] + + if isinstance(spec, list): + return [self._rel_projection_spec(node) for node in spec] + + if isinstance(spec, dict): + return [self._rel_projection_spec(node, name) for name, node in spec.items()] + + raise TypeError(f"Invalid relationship projection specification: {spec}") + + def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> RelationshipProjection: + if isinstance(spec, str): + return 
RelationshipProjection(name=name or spec, source_type=spec) + + raise TypeError(f"Invalid relationship projection specification: {spec}") + + def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]: + raise TypeError(f"Invalid relationship projection specification: {properties}") + + # + # def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]": + # pass diff --git a/graphdatascience/graph/graph_proc_runner.py b/graphdatascience/graph/graph_proc_runner.py index f7633dcd8..90c1d6763 100644 --- a/graphdatascience/graph/graph_proc_runner.py +++ b/graphdatascience/graph/graph_proc_runner.py @@ -27,6 +27,7 @@ from .graph_sample_runner import GraphSampleRunner from .graph_type_check import graph_type_check, graph_type_check_optional from .ogb_loader import OGBLLoader, OGBNLoader +from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner Strings = Union[str, List[str]] @@ -165,6 +166,11 @@ def project(self) -> GraphProjectRunner: self._namespace += ".project" return GraphProjectRunner(self._query_runner, self._namespace, self._server_version) + @property + def cypher(self) -> GraphCypherRunner: + self._namespace += ".project" + return GraphCypherRunner(self._query_runner, self._namespace, self._server_version) + @property def export(self) -> GraphExportRunner: self._namespace += ".export" diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py new file mode 100644 index 000000000..b96e2bb0a --- /dev/null +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ -0,0 +1,224 @@ +import pytest +from pandas import DataFrame + +from .conftest import CollectingQueryRunner +from graphdatascience.graph_data_science import GraphDataScience +from graphdatascience.server_version.server_version import ServerVersion + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_all(runner: CollectingQueryRunner, gds: 
GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g") + + assert G.name() == "g" + assert runner.last_params() == dict(graph_name="g") + + assert ( + runner.last_query() + == """MATCH (source)-->(target) +RETURN gds.graph.project($graph_name, source, target)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_disconnected(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", allow_disconnected_nodes=True) + + assert G.name() == "g" + assert runner.last_params() == dict(graph_name="g") + + assert ( + runner.last_query() + == """MATCH (source) +OPTIONAL MATCH (source)-->(target) +RETURN gds.graph.project($graph_name, source, target)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_inverse_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", inverse=True) # TODO: or using orientation="INVERSE"? + + assert G.name() == "g" + assert runner.last_params() == dict(graph_name="g") + + assert ( + runner.last_query() + == """MATCH (source)<--(target) +RETURN gds.graph.project($graph_name, source, target)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_single_node_label(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes="A") + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]} + ) + + assert ( + runner.last_query() + == """MATCH (source:A)-->(target:A) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_disconnected_nodes_single_node_label(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes="A", allow_disconnected_nodes=True) + + 
assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]} + ) + + assert ( + runner.last_query() + == """MATCH (source:A) +OPTIONAL MATCH (source)-->(target:A) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_single_node_label_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=dict(Target="Label")) + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": ["Target"], "targetNodeLabels": ["Target"]} + ) + + assert ( + runner.last_query() + == """MATCH (source:Label)-->(target:Label) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_multiple_node_labels_and(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="AND") + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]} + ) + + assert ( + runner.last_query() + == """MATCH (source:A:B)-->(target:A:B) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_disconnected_nodes_multiple_node_labels_and(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="AND", allow_disconnected_nodes=True) + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]} + ) + + assert ( + runner.last_query() + == """MATCH (source:A:B) 
+OPTIONAL MATCH (source)-->(target:A:B) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR") + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"} + ) + + assert ( + runner.last_query() + == """MATCH (source)-->(target) +WHERE (source:A OR source:B) AND (target:A OR target:B) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True) + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"} + ) + + assert ( + runner.last_query() + == """MATCH (source) +WHERE source:A OR source:B +OPTIONAL MATCH (source)-->(target) +WHERE target:A OR target:B +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_single_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL") + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", + data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"}, + ) + + assert ( + runner.last_query() + == """MATCH (source:A)-[:REL]->(target:A) +RETURN gds.graph.project($graph_name, 
source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_disconnected_nodes_single_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL", allow_disconnected_nodes=True) + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", + data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"}, + ) + + assert ( + runner.last_query() + == """MATCH (source:A) +OPTIONAL MATCH (source)-[:REL]->(target:A) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"]) + + assert G.name() == "g" + assert runner.last_params() == dict( + graph_name="g", + data_config={ + "sourceNodeLabels": "labels(source)", + "targetNodeLabels": "labels(target)", + "relationshipTypes": "type(rel)", + }, + ) + + assert ( + runner.last_query() + == """MATCH (source)-[rel:REL1|REL2]->(target) +WHERE (source:A OR source:B) AND (target:A OR target:B) +RETURN gds.graph.project($graph_name, source, target, $data_config)""" + ) From f7735c0362e5157bf6bf2b6d81aad8274aa301ed Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Thu, 15 Jun 2023 16:17:30 +0200 Subject: [PATCH 2/9] Move non-static data config from params into query --- graphdatascience/graph/graph_cypher_runner.py | 13 +++++- .../tests/unit/test_graph_cypher.py | 40 ++++++++----------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index 2bacb4068..ee32c228a 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ 
b/graphdatascience/graph/graph_cypher_runner.py @@ -110,6 +110,7 @@ def project( query_params = {"graph_name": graph_name} data_config = {} + data_config_is_static = True nodes = self._node_projections_spec(nodes) rels = self._rel_projections_spec(relationships) @@ -142,6 +143,7 @@ def project( data_config["sourceNodeLabels"] = "labels(source)" data_config["targetNodeLabels"] = "labels(target)" + data_config_is_static = False else: raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}") @@ -152,6 +154,7 @@ def project( else: rel_var = "rel" data_config["relationshipTypes"] = "type(rel)" + data_config_is_static = False match_pattern = match_pattern._replace( type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]" ) @@ -169,8 +172,11 @@ def project( args = ["$graph_name", "source", "target"] if data_config: - query_params["data_config"] = data_config - args += ["$data_config"] + if data_config_is_static: + query_params["data_config"] = data_config + args += ["$data_config"] + else: + args += [self._render_map(data_config)] if config: query_params["config"] = config @@ -237,6 +243,9 @@ def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> Relatio def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]: raise TypeError(f"Invalid relationship projection specification: {properties}") + def _render_map(self, mapping: dict[str, Any]) -> str: + return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}" + # # def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]": # pass diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py index b96e2bb0a..4d9923a7b 100644 --- a/graphdatascience/tests/unit/test_graph_cypher.py +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ -136,15 +136,14 @@ def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataSc G, _ 
= gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR") assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"} - ) + assert runner.last_params() == dict(graph_name="g") - assert ( - runner.last_query() - == """MATCH (source)-->(target) + assert runner.last_query() == ( + """MATCH (source)-->(target) WHERE (source:A OR source:B) AND (target:A OR target:B) -RETURN gds.graph.project($graph_name, source, target, $data_config)""" +RETURN gds.graph.project($graph_name, source, target, {""" + "sourceNodeLabels: labels(source), " + "targetNodeLabels: labels(target)})" ) @@ -153,17 +152,16 @@ def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunne G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"} - ) + assert runner.last_params() == dict(graph_name="g") - assert ( - runner.last_query() - == """MATCH (source) + assert runner.last_query() == ( + """MATCH (source) WHERE source:A OR source:B OPTIONAL MATCH (source)-->(target) WHERE target:A OR target:B -RETURN gds.graph.project($graph_name, source, target, $data_config)""" +RETURN gds.graph.project($graph_name, source, target, {""" + "sourceNodeLabels: labels(source), " + "targetNodeLabels: labels(target)})" ) @@ -207,18 +205,14 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"]) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", - data_config={ - "sourceNodeLabels": "labels(source)", - "targetNodeLabels": "labels(target)", - "relationshipTypes": "type(rel)", - }, - ) + assert runner.last_params() == 
dict(graph_name="g") assert ( runner.last_query() == """MATCH (source)-[rel:REL1|REL2]->(target) WHERE (source:A OR source:B) AND (target:A OR target:B) -RETURN gds.graph.project($graph_name, source, target, $data_config)""" +RETURN gds.graph.project($graph_name, source, target, {""" + "sourceNodeLabels: labels(source), " + "targetNodeLabels: labels(target), " + "relationshipTypes: type(rel)})" ) From 143270fb333c2101762020ca29881786609434dc Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Thu, 15 Jun 2023 16:54:46 +0200 Subject: [PATCH 3/9] Add basic node property support --- graphdatascience/graph/graph_cypher_runner.py | 71 ++++++++++++++++--- .../tests/unit/test_graph_cypher.py | 31 ++++++++ 2 files changed, 94 insertions(+), 8 deletions(-) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index ee32c228a..39e349490 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from collections import defaultdict, namedtuple from typing import Any, NamedTuple, Optional, Tuple from pandas import Series @@ -62,6 +62,12 @@ def __str__(self) -> str: return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})" +class LabelPropertyMapping(NamedTuple): + label: str + property_key: str + default_value: Optional[Any] = None + + class GraphCypherRunner(IllegalAttrChecker): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None: if server_version < ServerVersion(2, 4, 0): @@ -121,6 +127,8 @@ def project( right_arrow="-" if inverse else "->", ) + label_mappings = defaultdict(list) + if nodes: if len(nodes) == 1 or combine_labels_with == "AND": match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}") @@ -147,14 +155,22 @@ def project( else: raise ValueError(f"Invalid value for combine_labels_with: 
{combine_labels_with}") + for spec in nodes: + if spec.properties: + for prop in spec.properties: + label_mappings[spec.source_label].append( + LabelPropertyMapping(spec.source_label, prop.property_key, prop.default_value) + ) + + rel_var = "" if rels: if len(rels) == 1: - rel_var = "" data_config["relationshipType"] = rels[0].source_type else: rel_var = "rel" data_config["relationshipTypes"] = "type(rel)" data_config_is_static = False + match_pattern = match_pattern._replace( type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]" ) @@ -169,6 +185,24 @@ def project( match_part = str(match_part) + case_part = [] + if label_mappings: + with_rel = f", {rel_var}" if rel_var else "" + case_part = [f"WITH source, target{with_rel}"] + for kind in ["source", "target"]: + case_part.append("CASE") + + for label, mappings in label_mappings.items(): + mappings = ", ".join(f".{key.property_key}" for key in mappings) + when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mappings}}}]" + case_part.append(when_part) + + case_part.append(f"END AS {kind}NodeProperties") + + data_config["sourceNodeProperties"] = "sourceNodeProperties" + data_config["targetNodeProperties"] = "targetNodeProperties" + data_config_is_static = False + args = ["$graph_name", "source", "target"] if data_config: @@ -184,9 +218,7 @@ def project( return_part = f"RETURN {self._namespace}({', '.join(args)})" - query = "\n".join(part for part in [match_part, return_part] if part) - - print(query) + query = "\n".join(part for part in [match_part, *case_part, return_part] if part) result = self._query_runner.run_query_with_logging( query, @@ -208,16 +240,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]: if isinstance(spec, dict): return [self._node_projection_spec(node, name) for name, node in spec.items()] - raise TypeError(f"Invalid node projection specification: {spec}") + raise TypeError(f"Invalid node projections specification: {spec}") def 
_node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection: if isinstance(spec, str): return NodeProjection(name=name or spec, source_label=spec) + if name is None: + raise ValueError(f"Node projections with properties must use the dict syntax: {spec}") + + if isinstance(spec, dict): + properties = [self._node_properties_spec(prop, name) for name, prop in spec.items()] + return NodeProjection(name=name, source_label=name, properties=properties) + + if isinstance(spec, list): + properties = [self._node_properties_spec(prop) for prop in spec] + return NodeProjection(name=name, source_label=name, properties=properties) + raise TypeError(f"Invalid node projection specification: {spec}") - def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]: - raise TypeError(f"Invalid node projection specification: {properties}") + def _node_properties_spec(self, spec: Any, name: Optional[str] = None) -> NodeProperty: + if isinstance(spec, str): + return NodeProperty(name=name or spec, property_key=spec) + + if name is None: + raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}") + + if spec is True: + return NodeProperty(name=name, property_key=name) + + if isinstance(spec, dict): + return NodeProperty(name=name, property_key=name, **spec) + + raise TypeError(f"Invalid node property specification: {spec}") def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]: if spec is None or spec is False: diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py index 4d9923a7b..baa4b2582 100644 --- a/graphdatascience/tests/unit/test_graph_cypher.py +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ -216,3 +216,34 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien "targetNodeLabels: labels(target), " "relationshipTypes: type(rel)})" ) + + +@pytest.mark.parametrize("server_version", 
[ServerVersion(2, 4, 0)]) +def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project( + "g", nodes=dict(L1=["prop1"], L2=["prop2", "prop3"], L3=dict(prop4=True, prop5=dict())) + ) + + assert G.name() == "g" + assert runner.last_params() == dict(graph_name="g") + + assert runner.last_query() == ( + """MATCH (source)-->(target) +WHERE (source:L1 OR source:L2 OR source:L3) AND (target:L1 OR target:L2 OR target:L3) +WITH source, target +CASE +WHEN 'L1' in labels(source) THEN [source {.prop1}] +WHEN 'L2' in labels(source) THEN [source {.prop2, .prop3}] +WHEN 'L3' in labels(source) THEN [source {.prop4, .prop5}] +END AS sourceNodeProperties +CASE +WHEN 'L1' in labels(target) THEN [target {.prop1}] +WHEN 'L2' in labels(target) THEN [target {.prop2, .prop3}] +WHEN 'L3' in labels(target) THEN [target {.prop4, .prop5}] +END AS targetNodeProperties +RETURN gds.graph.project($graph_name, source, target, {""" + "sourceNodeLabels: labels(source), " + "targetNodeLabels: labels(target), " + "sourceNodeProperties: sourceNodeProperties, " + "targetNodeProperties: targetNodeProperties})" + ) From f6c969932e0a5a72f353b025ec662ae90081f675 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Thu, 15 Jun 2023 18:32:36 +0200 Subject: [PATCH 4/9] [WIP] node properties alias --- graphdatascience/graph/graph_cypher_runner.py | 20 +++++++++---- .../tests/unit/test_graph_cypher.py | 28 ++++++++++++++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index 39e349490..4ad8aec80 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -185,6 +185,9 @@ def project( match_part = str(match_part) + print("nodes", nodes) + print("labels", label_mappings) + case_part = [] if label_mappings: with_rel = f", {rel_var}" if rel_var else "" @@ -263,14 +266,21 @@ def 
_node_properties_spec(self, spec: Any, name: Optional[str] = None) -> NodePr if isinstance(spec, str): return NodeProperty(name=name or spec, property_key=spec) - if name is None: - raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}") + if isinstance(spec, dict): + name = spec.pop("name", name) + if name is None: + raise ValueError( + f"Node properties must specify either a name in the outer dict or by using the `name` key: {spec}" + ) + property_key = spec.pop("property_key", name) + + return NodeProperty(name=name, property_key=property_key, **spec) if spec is True: - return NodeProperty(name=name, property_key=name) + if name is None: + raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}") - if isinstance(spec, dict): - return NodeProperty(name=name, property_key=name, **spec) + return NodeProperty(name=name, property_key=name) raise TypeError(f"Invalid node property specification: {spec}") diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py index baa4b2582..a30cf1953 100644 --- a/graphdatascience/tests/unit/test_graph_cypher.py +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ -220,8 +220,11 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien @pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + # G, _ = gds.graph.cypher.project( + # "g", nodes=dict(L1=["prop1"], L2=["prop2", "prop3"], L3=dict(prop4=True, prop5=dict())) + # ) G, _ = gds.graph.cypher.project( - "g", nodes=dict(L1=["prop1"], L2=["prop2", "prop3"], L3=dict(prop4=True, prop5=dict())) + "g", nodes={"L1": ["prop1"], "L2": ["prop2", "prop3"], "L3": {"prop4": True, "prop5": {}}} ) assert G.name() == "g" @@ -247,3 +250,26 @@ def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) - "sourceNodeProperties: 
sourceNodeProperties, " "targetNodeProperties: targetNodeProperties})" ) + + +@pytest.mark.skip(reason="Not implemented yet") +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_node_properties_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.project( + "g", nodes=dict(A=dict(target_prop1="source_prop1", target_prop2=dict(property_key="source_prop2"))) + ) + + assert G.name() == "g" + assert runner.last_params() == dict(graph_name="g") + + assert runner.last_query() == ( + """MATCH (source:A)-->(target:A) +WITH source, target, """ + "[{target_prop1: source.source_prop1, target_prop1: source.source_prop2}] AS sourceNodeProperties" + """[{target_prop1: target.source_prop1, target_prop1: target.source_prop2}] AS targetNodeProperties + RETURN gds.graph.project($graph_name, source, target, {""" + "sourceNodeLabels: labels(source), " + "targetNodeLabels: labels(target), " + "sourceNodeProperties: sourceNodeProperties, " + "targetNodeProperties: targetNodeProperties})" + ) From 40365dfba92a8126fd9a3fa767f905dc7ec2dec3 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 16 Jun 2023 13:41:29 +0200 Subject: [PATCH 5/9] Add proper types and type usages --- graphdatascience/graph/graph_cypher_runner.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index 4ad8aec80..acaae4472 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -1,5 +1,5 @@ -from collections import defaultdict, namedtuple -from typing import Any, NamedTuple, Optional, Tuple +from collections import defaultdict +from typing import Any, Dict, NamedTuple, Optional, Tuple from pandas import Series @@ -33,7 +33,7 @@ class RelationshipProjection(NamedTuple): properties: Optional[list[RelationshipProperty]] = None -class MatchPart(NamedTuple): +class 
MatchParts(NamedTuple): match: str = "" source_where: str = "" optional_match: str = "" @@ -113,15 +113,15 @@ def project( Additional configuration for the projection. """ - query_params = {"graph_name": graph_name} + query_params: Dict[str, Any] = {"graph_name": graph_name} - data_config = {} + data_config: Dict[str, Any] = {} data_config_is_static = True nodes = self._node_projections_spec(nodes) rels = self._rel_projections_spec(relationships) - match_part = MatchPart() + match_parts = MatchParts() match_pattern = MatchPattern( left_arrow="<-" if inverse else "-", right_arrow="-" if inverse else "->", @@ -141,11 +141,11 @@ def project( source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in nodes) target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in nodes) if allow_disconnected_nodes: - match_part = match_part._replace( + match_parts = match_parts._replace( source_where=f"WHERE {source_labels_filter}", optional_where=f"WHERE {target_labels_filter}" ) else: - match_part = match_part._replace( + match_parts = match_parts._replace( source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})" ) @@ -177,13 +177,13 @@ def project( source = f"(source{match_pattern.label_filter})" if allow_disconnected_nodes: - match_part = match_part._replace( + match_parts = match_parts._replace( match=f"MATCH {source}", optional_match=f"OPTIONAL MATCH (source){match_pattern}" ) else: - match_part = match_part._replace(match=f"MATCH {source}{match_pattern}") + match_parts = match_parts._replace(match=f"MATCH {source}{match_pattern}") - match_part = str(match_part) + match_part = str(match_parts) print("nodes", nodes) print("labels", label_mappings) @@ -196,8 +196,8 @@ def project( case_part.append("CASE") for label, mappings in label_mappings.items(): - mappings = ", ".join(f".{key.property_key}" for key in mappings) - when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mappings}}}]" + mapping_projection = ", 
".join(f".{key.property_key}" for key in mappings) + when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mapping_projection}}}]" case_part.append(when_part) case_part.append(f"END AS {kind}NodeProperties") @@ -223,12 +223,10 @@ def project( query = "\n".join(part for part in [match_part, *case_part, return_part] if part) - result = self._query_runner.run_query_with_logging( - query, - query_params, - ).squeeze() + result = self._query_runner.run_query_with_logging(query, query_params) + result = result.squeeze() - return Graph(graph_name, self._query_runner, self._server_version), result + return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore def _node_projections_spec(self, spec: Any) -> list[NodeProjection]: if spec is None or spec is False: @@ -305,10 +303,10 @@ def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> Relatio raise TypeError(f"Invalid relationship projection specification: {spec}") - def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]: + def _rel_properties_spec(self, properties: Dict[str, Any]) -> list[RelationshipProperty]: raise TypeError(f"Invalid relationship projection specification: {properties}") - def _render_map(self, mapping: dict[str, Any]) -> str: + def _render_map(self, mapping: Dict[str, Any]) -> str: return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}" # From 5d6bea0321df3f5c3dc2829a349d0367ada25e04 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 16 Jun 2023 13:42:26 +0200 Subject: [PATCH 6/9] Use dict literal syntax over dict function --- .../tests/unit/test_graph_cypher.py | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py index a30cf1953..a16d231f1 100644 --- a/graphdatascience/tests/unit/test_graph_cypher.py +++ 
b/graphdatascience/tests/unit/test_graph_cypher.py @@ -1,5 +1,4 @@ import pytest -from pandas import DataFrame from .conftest import CollectingQueryRunner from graphdatascience.graph_data_science import GraphDataScience @@ -11,7 +10,7 @@ def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: G, _ = gds.graph.cypher.project("g") assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert ( runner.last_query() @@ -25,7 +24,7 @@ def test_disconnected(runner: CollectingQueryRunner, gds: GraphDataScience) -> N G, _ = gds.graph.cypher.project("g", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert ( runner.last_query() @@ -40,7 +39,7 @@ def test_inverse_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> G, _ = gds.graph.cypher.project("g", inverse=True) # TODO: or using orientation="INVERSE"? 
assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert ( runner.last_query() @@ -54,9 +53,10 @@ def test_single_node_label(runner: CollectingQueryRunner, gds: GraphDataScience) G, _ = gds.graph.cypher.project("g", nodes="A") assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]} - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]}, + } assert ( runner.last_query() @@ -70,9 +70,10 @@ def test_disconnected_nodes_single_node_label(runner: CollectingQueryRunner, gds G, _ = gds.graph.cypher.project("g", nodes="A", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]} - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]}, + } assert ( runner.last_query() @@ -84,12 +85,13 @@ def test_disconnected_nodes_single_node_label(runner: CollectingQueryRunner, gds @pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) def test_single_node_label_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: - G, _ = gds.graph.cypher.project("g", nodes=dict(Target="Label")) + G, _ = gds.graph.cypher.project("g", nodes={"Target": "Label"}) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": ["Target"], "targetNodeLabels": ["Target"]} - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["Target"], "targetNodeLabels": ["Target"]}, + } assert ( runner.last_query() @@ -103,9 +105,10 @@ def test_multiple_node_labels_and(runner: CollectingQueryRunner, gds: GraphDataS G, _ = gds.graph.cypher.project("g", nodes=["A", 
"B"], combine_labels_with="AND") assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]} - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]}, + } assert ( runner.last_query() @@ -119,9 +122,10 @@ def test_disconnected_nodes_multiple_node_labels_and(runner: CollectingQueryRunn G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="AND", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", data_config={"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]} - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]}, + } assert ( runner.last_query() @@ -136,7 +140,7 @@ def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataSc G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR") assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert runner.last_query() == ( """MATCH (source)-->(target) @@ -152,7 +156,7 @@ def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunne G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert runner.last_query() == ( """MATCH (source) @@ -170,10 +174,10 @@ def test_single_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL") assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", - data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": 
["A"], "relationshipType": "REL"}, - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"}, + } assert ( runner.last_query() @@ -187,10 +191,10 @@ def test_disconnected_nodes_single_multi_graph(runner: CollectingQueryRunner, gd G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL", allow_disconnected_nodes=True) assert G.name() == "g" - assert runner.last_params() == dict( - graph_name="g", - data_config={"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"}, - ) + assert runner.last_params() == { + "graph_name": "g", + "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"}, + } assert ( runner.last_query() @@ -205,7 +209,7 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"]) assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert ( runner.last_query() @@ -220,15 +224,12 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien @pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: - # G, _ = gds.graph.cypher.project( - # "g", nodes=dict(L1=["prop1"], L2=["prop2", "prop3"], L3=dict(prop4=True, prop5=dict())) - # ) G, _ = gds.graph.cypher.project( "g", nodes={"L1": ["prop1"], "L2": ["prop2", "prop3"], "L3": {"prop4": True, "prop5": {}}} ) assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert runner.last_query() == ( """MATCH (source)-->(target) @@ -256,11 +257,11 @@ def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) - 
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) def test_node_properties_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: G, _ = gds.graph.cypher.project( - "g", nodes=dict(A=dict(target_prop1="source_prop1", target_prop2=dict(property_key="source_prop2"))) + "g", nodes={"A": {"target_prop1": "source_prop1", "target_prop2": {"property_key": "source_prop2"}}} ) assert G.name() == "g" - assert runner.last_params() == dict(graph_name="g") + assert runner.last_params() == {"graph_name": "g"} assert runner.last_query() == ( """MATCH (source:A)-->(target:A) From 1c59a26945a0818f9c00f6c1af63fd77e402c836 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 16 Jun 2023 13:44:32 +0200 Subject: [PATCH 7/9] Use 'if not in' over 'not if in' --- doc/tests/test_client_only_endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/tests/test_client_only_endpoints.py b/doc/tests/test_client_only_endpoints.py index b8f56e15b..7c8f86a5a 100644 --- a/doc/tests/test_client_only_endpoints.py +++ b/doc/tests/test_client_only_endpoints.py @@ -59,7 +59,7 @@ def find_covered_server_endpoints() -> List[str]: driver.close() - return [ep["name"] for ep in all_server_endpoints if not ep["name"] in IGNORED_SERVER_ENDPOINTS] + return [ep["name"] for ep in all_server_endpoints if ep["name"] not in IGNORED_SERVER_ENDPOINTS] def check_rst_files(endpoints: List[str]) -> None: From 65b620aef7cdc13ab1ee897a5164947906c8aeea Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 16 Jun 2023 13:45:13 +0200 Subject: [PATCH 8/9] Add returns documentation --- graphdatascience/graph/graph_cypher_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index acaae4472..f40a43eef 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -111,6 +111,10 @@ def project( Allowed values are 'AND' and 
'OR'. **config : Any Additional configuration for the projection. + + Returns + ------- + A tuple of the projected graph and statistics about the projection """ query_params: Dict[str, Any] = {"graph_name": graph_name} From 02d89af8fa2602d79b5e7751f2fcb6a3b3a9d806 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 16 Jun 2023 13:45:40 +0200 Subject: [PATCH 9/9] Add small wrapper over run_cypher --- graphdatascience/graph/graph_cypher_runner.py | 66 +++++++++++++++++++ .../tests/unit/test_graph_cypher.py | 32 +++++++++ 2 files changed, 98 insertions(+) diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py index f40a43eef..34dcf0a3f 100644 --- a/graphdatascience/graph/graph_cypher_runner.py +++ b/graphdatascience/graph/graph_cypher_runner.py @@ -4,6 +4,7 @@ from pandas import Series from ..error.illegal_attr_checker import IllegalAttrChecker +from ..query_runner.arrow_query_runner import ArrowQueryRunner from ..query_runner.query_runner import QueryRunner from ..server_version.server_version import ServerVersion from .graph_object import Graph @@ -232,6 +233,71 @@ def project( return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore + def run_project( + self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None + ) -> Tuple[Graph, "Series[Any]"]: + """ + Run a Cypher projection. + The provided query must end with a `RETURN gds.graph.project(...)` call. 
+ + Parameters + ---------- + query: str + the Cypher projection query + params: Dict[str, Any] + parameters to the query + database: str + the database on which to run the query + + Returns + ------- + A tuple of the projected graph and statistics about the projection + """ + + return_clause = f"RETURN {self._namespace}" + + return_index = query.rfind(return_clause) + if return_index == -1: + raise ValueError(f"Invalid query, the query must end with a `{return_clause}` clause: {query}") + + return_index += len(return_clause) + return_part = query[return_index:] + + # Remove surrounding parentheses and whitespace + right_paren = return_part.rfind(")") + 1 + return_part = return_part[:right_paren].strip("() \n\t") + + graph_name = return_part.split(",", maxsplit=1)[0] + graph_name = graph_name.strip() + + if graph_name.startswith("$"): + if params is None: + raise ValueError( + f"Invalid query, the query references parameter `{graph_name}` but no params were given" + ) + + graph_name = graph_name[1:] + graph_name = params[graph_name] + else: + # remove the quotes + graph_name = graph_name.strip("'\"") + + # remove possible `AS graph` from the end of the query + end_of_query = return_index + right_paren + query = query[:end_of_query] + + # run_cypher + qr = self._query_runner + + # The Arrow query runner should not be used to execute arbitrary Cypher + if isinstance(qr, ArrowQueryRunner): + qr = qr.fallback_query_runner() + + result = qr.run_query(query, params, database, False) + result = result.squeeze() + + return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore + def _node_projections_spec(self, spec: Any) -> list[NodeProjection]: if spec is None or spec is False: return [] diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py index a16d231f1..532690ebb 100644 --- a/graphdatascience/tests/unit/test_graph_cypher.py +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ 
-5,6 +5,38 @@ from graphdatascience.server_version.server_version import ServerVersion +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_run_project(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)") + + assert G.name() == "gg" + assert runner.last_params() == {} + + assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)" + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_run_project_with_return_as(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t) AS graph") + + assert G.name() == "gg" + assert runner.last_params() == {} + + assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)" + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_run_project_with_graph_name_parameter(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.run_project( + "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)", params={"graph_name": "gg"} + ) + + assert G.name() == "gg" + assert runner.last_params() == {"graph_name": "gg"} + + assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)" + + @pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: G, _ = gds.graph.cypher.project("g")