diff --git a/python-wrapper/src/neo4j_viz/neo4j.py b/python-wrapper/src/neo4j_viz/neo4j.py index d5eeb40..3cdaa9b 100644 --- a/python-wrapper/src/neo4j_viz/neo4j.py +++ b/python-wrapper/src/neo4j_viz/neo4j.py @@ -3,8 +3,8 @@ from typing import Optional, Union import neo4j.graph -from neo4j import Result from pydantic import BaseModel, ValidationError +from neo4j import Driver, Result, RoutingControl from neo4j_viz.node import Node from neo4j_viz.relationship import Relationship @@ -20,14 +20,15 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) -> def from_neo4j( - result: Union[neo4j.graph.Graph, Result], + result: Union[neo4j.graph.Graph, Result, Driver], size_property: Optional[str] = None, node_caption: Optional[str] = "labels", relationship_caption: Optional[str] = "type", node_radius_min_max: Optional[tuple[float, float]] = (3, 60), + row_limit: int = 10_000, ) -> VisualizationGraph: """ - Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object. + Create a VisualizationGraph from a Neo4j Graph, Neo4j Result object or Neo4j Driver. All node and relationship properties will be included in the visualization graph. If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as @@ -36,8 +37,9 @@ def from_neo4j( Parameters ---------- - result : Union[neo4j.graph.Graph, Result] - Query result either in shape of a Graph or result. + result : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver] + Either a query result either in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in + which case a simple default query will be executed to retrieve the graph data. size_property : str, optional Property to use for node size, by default None. node_caption : str, optional @@ -47,14 +49,25 @@ def from_neo4j( node_radius_min_max : tuple[float, float], optional Minimum and maximum node radius, by default (3, 60). To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. + row_limit : int, optional + Maximum number of rows to return from the query, by default 10_000. + This is only used if a `neo4j.Driver` is passed as `result` argument, otherwise the limit is ignored. """ if isinstance(result, Result): graph = result.graph() elif isinstance(result, neo4j.graph.Graph): graph = result + elif isinstance(result, Driver): + graph = result.execute_query( + f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}", + routing_=RoutingControl.READ, + result_transformer_=Result.graph, + ) else: - raise ValueError(f"Invalid input type `{type(result)}`. Expected `neo4j.Graph` or `neo4j.Result`") + raise ValueError( + f"Invalid input type `{type(result)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`" + ) all_node_field_aliases = Node.all_validation_aliases() all_rel_field_aliases = Relationship.all_validation_aliases()