diff --git a/changelog.md b/changelog.md index 76f2e1c..e35e198 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ ## New features +* Allow visualization based only on relationship DataFrames, without specifying node DataFrames in `from_dfs` ## Bug fixes diff --git a/docs/source/integration.rst b/docs/source/integration.rst index 12e9b49..15454c0 100644 --- a/docs/source/integration.rst +++ b/docs/source/integration.rst @@ -38,6 +38,7 @@ The ``from_dfs`` method takes two mandatory positional parameters: on corresponding nodes under that field name. Otherwise, the column name will be a key in each node's `properties` dictionary, that maps to the node's corresponding value in the column. + If the graph has no node properties, ``None`` can be passed instead; the nodes are then derived from the source and target node IDs in the relationship DataFrame(s). * A Pandas ``DataFrame``, or iterable (eg. list) of DataFrames representing the relationships of the graph. The rows of the DataFrame(s) should represent the individual relationships, and the columns should represent the relationship IDs and attributes. 
diff --git a/python-wrapper/src/neo4j_viz/pandas.py b/python-wrapper/src/neo4j_viz/pandas.py index b201da8..694ebf6 100644 --- a/python-wrapper/src/neo4j_viz/pandas.py +++ b/python-wrapper/src/neo4j_viz/pandas.py @@ -13,18 +13,40 @@ def _from_dfs( - node_dfs: DFS_TYPE, + node_dfs: Optional[DFS_TYPE], rel_dfs: DFS_TYPE, node_radius_min_max: Optional[tuple[float, float]] = (3, 60), rename_properties: Optional[dict[str, str]] = None, ) -> VisualizationGraph: + relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties) + + if node_dfs is None: + has_size = False + node_ids = set() + for rel in relationships: + node_ids.add(rel.source) + node_ids.add(rel.target) + nodes = [Node(id=id) for id in node_ids] + else: + nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties) + + VG = VisualizationGraph(nodes=nodes, relationships=relationships) + + if node_radius_min_max is not None and has_size: + VG.resize_nodes(node_radius_min_max=node_radius_min_max) + + return VG + + +def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> tuple[list[Node], bool]: if isinstance(node_dfs, DataFrame): node_dfs_iter: Iterable[DataFrame] = [node_dfs] + elif node_dfs is None: + node_dfs_iter = [] else: node_dfs_iter = node_dfs all_node_field_aliases = Node.all_validation_aliases() - all_rel_field_aliases = Relationship.all_validation_aliases() has_size = True nodes = [] @@ -42,13 +64,18 @@ def _from_dfs( properties[key] = value nodes.append(Node(**top_level, properties=properties)) + return nodes, has_size + + +def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> list[Relationship]: + all_rel_field_aliases = Relationship.all_validation_aliases() if isinstance(rel_dfs, DataFrame): rel_dfs_iter: Iterable[DataFrame] = [rel_dfs] else: rel_dfs_iter = rel_dfs + relationships: list[Relationship] = [] - relationships = [] for rel_df in rel_dfs_iter: for _, row in rel_df.iterrows(): 
top_level = {} @@ -62,17 +89,11 @@ def _from_dfs( properties[key] = value relationships.append(Relationship(**top_level, properties=properties)) - - VG = VisualizationGraph(nodes=nodes, relationships=relationships) - - if node_radius_min_max is not None and has_size: - VG.resize_nodes(node_radius_min_max=node_radius_min_max) - - return VG + return relationships def from_dfs( - node_dfs: DFS_TYPE, + node_dfs: Optional[DFS_TYPE], rel_dfs: DFS_TYPE, node_radius_min_max: Optional[tuple[float, float]] = (3, 60), ) -> VisualizationGraph: @@ -85,8 +106,9 @@ def from_dfs( Parameters ---------- - node_dfs: Union[DataFrame, Iterable[DataFrame]] + node_dfs: Optional[Union[DataFrame, Iterable[DataFrame]]] DataFrame or iterable of DataFrames containing node data. + If None, the nodes will be created from the source and target node ids in the rel_dfs. rel_dfs: Union[DataFrame, Iterable[DataFrame]] DataFrame or iterable of DataFrames containing relationship data. node_radius_min_max : tuple[float, float], optional diff --git a/python-wrapper/tests/test_pandas.py b/python-wrapper/tests/test_pandas.py index 4432926..76fb77f 100644 --- a/python-wrapper/tests/test_pandas.py +++ b/python-wrapper/tests/test_pandas.py @@ -1,6 +1,7 @@ from pandas import DataFrame from pydantic_extra_types.color import Color +from neo4j_viz.node import Node from neo4j_viz.pandas import from_dfs @@ -45,6 +46,31 @@ def test_from_df() -> None: assert VG.relationships[1].properties == {"weight": 2.0} +def test_from_rel_dfs() -> None: + relationships = [ + DataFrame( + { + "source": [0, 1], + "target": [1, 0], + "caption": ["REL", "REL2"], + "weight": [1.0, 2.0], + } + ), + DataFrame( + { + "source": [2, 3], + "target": [1, 0], + "caption": ["REL", "REL2"], + "weight": [1.0, 2.0], + } + ), + ] + VG = from_dfs(None, relationships) + + assert len(VG.relationships) == 4 + assert VG.nodes == [Node(id=id) for id in [0, 1, 2, 3]] + + def test_from_dfs() -> None: nodes = [ DataFrame(