From 4b568059953d3cc9a60993ed35093c2773b6a1c2 Mon Sep 17 00:00:00 2001 From: Adam Schill Collberg Date: Fri, 23 May 2025 09:45:36 +0200 Subject: [PATCH] Allow `Driver` as argument to `from_neo4j` --- python-wrapper/src/neo4j_viz/neo4j.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python-wrapper/src/neo4j_viz/neo4j.py b/python-wrapper/src/neo4j_viz/neo4j.py index 8973aca..ed0c7c6 100644 --- a/python-wrapper/src/neo4j_viz/neo4j.py +++ b/python-wrapper/src/neo4j_viz/neo4j.py @@ -3,7 +3,7 @@ from typing import Optional, Union import neo4j.graph -from neo4j import Result +from neo4j import Driver, Result, RoutingControl from neo4j_viz.node import Node from neo4j_viz.relationship import Relationship @@ -11,11 +11,12 @@ def from_neo4j( - result: Union[neo4j.graph.Graph, Result], + result: Union[neo4j.graph.Graph, Result, Driver], size_property: Optional[str] = None, node_caption: Optional[str] = "labels", relationship_caption: Optional[str] = "type", node_radius_min_max: Optional[tuple[float, float]] = (3, 60), + row_limit: int = 10_000, ) -> VisualizationGraph: """ Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object. @@ -44,8 +45,16 @@ def from_neo4j( graph = result.graph() elif isinstance(result, neo4j.graph.Graph): graph = result + elif isinstance(result, Driver): + graph = result.execute_query( + f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}", + routing_=RoutingControl.READ, + result_transformer_=Result.graph, + ) else: - raise ValueError(f"Invalid input type `{type(result)}`. Expected `neo4j.Graph` or `neo4j.Result`") + raise ValueError( + f"Invalid input type `{type(result)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`" + ) all_node_field_aliases = Node.all_validation_aliases() all_rel_field_aliases = Relationship.all_validation_aliases()