Skip to content

Commit 02d89af

Browse files
committed
Add small wrapper over run_cypher
1 parent 65b620a commit 02d89af

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas import Series
55

66
from ..error.illegal_attr_checker import IllegalAttrChecker
7+
from ..query_runner.arrow_query_runner import ArrowQueryRunner
78
from ..query_runner.query_runner import QueryRunner
89
from ..server_version.server_version import ServerVersion
910
from .graph_object import Graph
@@ -232,6 +233,71 @@ def project(
232233

233234
return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore
234235

236+
def run_project(
237+
self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
238+
) -> Tuple[Graph, "Series[Any]"]:
239+
"""
240+
Run a Cypher projection.
241+
The provided query must end with a `RETURN gds.graph.project(...)` call.
242+
243+
Parameters
244+
----------
245+
query: str
246+
the Cypher projection query
247+
params: Dict[str, Any]
248+
parameters to the query
249+
database: str
250+
the database on which to run the query
251+
252+
Returns
253+
-------
254+
A tuple of the projected graph and statistics about the projection
255+
"""
256+
257+
return_clause = f"RETURN {self._namespace}"
258+
259+
return_index = query.rfind(return_clause)
260+
if return_index == -1:
261+
raise ValueError(f"Invalid query, the query must end with a `{return_clause}` clause: {query}")
262+
263+
return_index += len(return_clause)
264+
return_part = query[return_index:]
265+
266+
# Remove surrounding parentheses and whitespace
267+
right_paren = return_part.rfind(")") + 1
268+
return_part = return_part[:right_paren].strip("() \n\t")
269+
270+
graph_name = return_part.split(",", maxsplit=1)[0]
271+
graph_name = graph_name.strip()
272+
273+
if graph_name.startswith("$"):
274+
if params is None:
275+
raise ValueError(
276+
f"Invalid query, the query references parameter `{graph_name}` but no params were given"
277+
)
278+
279+
graph_name = graph_name[1:]
280+
graph_name = params[graph_name]
281+
else:
282+
# remove the quotes
283+
graph_name = graph_name.strip("'\"")
284+
285+
# remove possible `AS graph` from the end of the query
286+
end_of_query = return_index + right_paren
287+
query = query[:end_of_query]
288+
289+
# run_cypher
290+
qr = self._query_runner
291+
292+
# The Arrow query runner should not be used to execute arbitrary Cypher
293+
if isinstance(qr, ArrowQueryRunner):
294+
qr = qr.fallback_query_runner()
295+
296+
result = qr.run_query(query, params, database, False)
297+
result = result.squeeze()
298+
299+
return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore
300+
235301
def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
236302
if spec is None or spec is False:
237303
return []

graphdatascience/tests/unit/test_graph_cypher.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,38 @@
55
from graphdatascience.server_version.server_version import ServerVersion
66

77

8+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
9+
def test_run_project(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
10+
G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)")
11+
12+
assert G.name() == "gg"
13+
assert runner.last_params() == {}
14+
15+
assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
16+
17+
18+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
19+
def test_run_project_with_return_as(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
20+
G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t) AS graph")
21+
22+
assert G.name() == "gg"
23+
assert runner.last_params() == {}
24+
25+
assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
26+
27+
28+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
29+
def test_run_project_with_graph_name_parameter(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
30+
G, _ = gds.graph.cypher.run_project(
31+
"MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)", params={"graph_name": "gg"}
32+
)
33+
34+
assert G.name() == "gg"
35+
assert runner.last_params() == {"graph_name": "gg"}
36+
37+
assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)"
38+
39+
840
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
941
def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
1042
G, _ = gds.graph.cypher.project("g")

0 commit comments

Comments
 (0)