Skip to content

Commit e00de8b

Browse files
authored
Merge pull request #133 from neo4j/color-nodes-unhashable
Allow some unhashable types in `color_nodes`
2 parents 5b9da56 + 2562b5f commit e00de8b

File tree

2 files changed

+110
-3
lines changed

2 files changed

+110
-3
lines changed

python-wrapper/src/neo4j_viz/visualization_graph.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Iterable
5-
from typing import Optional
5+
from typing import Any, Hashable, Optional
66

77
from IPython.display import HTML
88
from pydantic_extra_types.color import Color, ColorType
@@ -201,7 +201,8 @@ def color_nodes(self, property: str, colors: Optional[ColorsType] = None, overri
201201
Parameters
202202
----------
203203
property:
204-
The property of the nodes to use for coloring.
204+
The property of the nodes to use for coloring. The type of this property must be hashable, or be a
205+
list, set or dict containing only hashable types.
205206
colors:
206207
The colors to use for the nodes. If a dictionary is given, it should map from property to color.
207208
If an iterable is given, the colors are used in order.
@@ -238,7 +239,11 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override
238239
prop_to_color = {}
239240
colors_iter = iter(colors)
240241
for node in self.nodes:
241-
prop = getattr(node, property)
242+
raw_prop = getattr(node, property)
243+
try:
244+
prop = self._make_hashable(raw_prop)
245+
except ValueError:
246+
raise ValueError(f"Unable to color nodes by unhashable property type '{type(raw_prop)}'")
242247

243248
if prop not in prop_to_color:
244249
next_color = next(colors_iter, None)
@@ -263,3 +268,22 @@ def _color_nodes_iter(self, property: str, colors: Iterable[ColorType], override
263268
f"Ran out of colors for property '{property}'. {len(prop_to_color)} colors were needed, but only "
264269
f"{len(set(prop_to_color.values()))} were given, so reused colors"
265270
)
271+
272+
@staticmethod
273+
def _make_hashable(raw_prop: Any) -> Hashable:
274+
prop = raw_prop
275+
if isinstance(raw_prop, list):
276+
prop = tuple(raw_prop)
277+
elif isinstance(raw_prop, set):
278+
prop = frozenset(raw_prop)
279+
elif isinstance(raw_prop, dict):
280+
prop = tuple(sorted(raw_prop.items()))
281+
282+
try:
283+
hash(prop)
284+
except TypeError:
285+
raise ValueError(f"Unable to convert property '{raw_prop}' of type {type(raw_prop)} to a hashable type")
286+
287+
assert isinstance(prop, Hashable)
288+
289+
return prop

python-wrapper/tests/test_colors.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,86 @@ def test_color_nodes_default() -> None:
103103
assert VG.nodes[1].color == Color(neo4j_colors[1])
104104
assert VG.nodes[2].color == Color(neo4j_colors[1])
105105
assert VG.nodes[3].color == Color(neo4j_colors[2])
106+
107+
108+
def test_color_nodes_lists() -> None:
109+
nodes = [
110+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels=["Person"]),
111+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels=["Product"]),
112+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels=["Product"]),
113+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels=["Person", "Product"]),
114+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels=["Person", "Product"]),
115+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels=["Product", "Person"]),
116+
]
117+
118+
VG = VisualizationGraph(nodes=nodes, relationships=[])
119+
120+
VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"])
121+
122+
assert VG.nodes[0].color == Color("#000000")
123+
assert VG.nodes[1].color == Color("#00ff00")
124+
assert VG.nodes[2].color == Color("#00ff00")
125+
assert VG.nodes[3].color == Color("#ff0000")
126+
assert VG.nodes[4].color == Color("#ff0000")
127+
assert VG.nodes[5].color == Color("#0000ff")
128+
129+
130+
def test_color_nodes_sets() -> None:
131+
nodes = [
132+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", labels={"Person"}),
133+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", labels={"Product"}),
134+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", labels={"Product"}),
135+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Both", labels={"Person", "Product"}),
136+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:2", caption="Both again", labels={"Person", "Product"}),
137+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:3", caption="Both reorder", labels={"Product", "Person"}),
138+
]
139+
140+
VG = VisualizationGraph(nodes=nodes, relationships=[])
141+
142+
VG.color_nodes("labels", ["#000000", "#00FF00", "#FF0000", "#0000FF"])
143+
144+
assert VG.nodes[0].color == Color("#000000")
145+
assert VG.nodes[1].color == Color("#00ff00")
146+
assert VG.nodes[2].color == Color("#00ff00")
147+
assert VG.nodes[3].color == Color("#ff0000")
148+
assert VG.nodes[4].color == Color("#ff0000")
149+
assert VG.nodes[4].color == Color("#ff0000")
150+
151+
152+
def test_color_nodes_dicts() -> None:
153+
nodes = [
154+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", config={"age": 18}),
155+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:6", caption="Product", config={"price": 100}),
156+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:11", caption="Product", config={"price": 100}),
157+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:1", caption="Product", config={"price": 1}),
158+
]
159+
160+
VG = VisualizationGraph(nodes=nodes, relationships=[])
161+
162+
VG.color_nodes("config", ["#000000", "#00FF00", "#FF0000", "#0000FF"])
163+
164+
assert VG.nodes[0].color == Color("#000000")
165+
assert VG.nodes[1].color == Color("#00ff00")
166+
assert VG.nodes[2].color == Color("#00ff00")
167+
assert VG.nodes[3].color == Color("#ff0000")
168+
169+
170+
def test_color_nodes_unhashable() -> None:
171+
nodes = [
172+
Node(
173+
id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0",
174+
caption="Person",
175+
config={"movies": ["Star Wars", "Star Trek"]},
176+
),
177+
]
178+
VG = VisualizationGraph(nodes=nodes, relationships=[])
179+
180+
with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type '<class 'dict'>'"):
181+
VG.color_nodes("config", ["#000000"])
182+
183+
nodes = [
184+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", list_of_lists=[[1, 2], [3, 4]]),
185+
]
186+
VG = VisualizationGraph(nodes=nodes, relationships=[])
187+
with pytest.raises(ValueError, match="Unable to color nodes by unhashable property type '<class 'list'>'"):
188+
VG.color_nodes("list_of_lists", ["#000000"])

0 commit comments

Comments
 (0)