Skip to content

Commit a7c6cbf

Browse files
adamnschFlorentinD
andcommitted
Warn when doing sampling in from_gds
Co-Authored-By: Florentin Dörre <[email protected]>
1 parent 3507e42 commit a7c6cbf

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

python-wrapper/src/neo4j_viz/gds.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from itertools import chain
45
from typing import Optional
56
from uuid import uuid4
@@ -99,6 +100,9 @@ def from_gds(
99100

100101
node_count = G.node_count()
101102
if node_count > max_node_count:
103+
warnings.warn(
104+
f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
105+
)
102106
sampling_ratio = float(max_node_count) / node_count
103107
sample_name = f"neo4j-viz_sample_{uuid4()}"
104108
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)

python-wrapper/tests/test_gds.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Any
23

34
import pandas as pd
@@ -267,7 +268,13 @@ def test_from_gds_sample(gds: Any) -> None:
267268
from neo4j_viz.gds import from_gds
268269

269270
with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
270-
VG = from_gds(gds, G)
271+
with pytest.warns(
272+
UserWarning,
273+
match=re.escape(
274+
"The 'hello' projection's node count (11000) exceeds `max_node_count` (10000), so subsampling will be applied. Increase `max_node_count` if needed"
275+
),
276+
):
277+
VG = from_gds(gds, G)
271278

272279
assert len(VG.nodes) >= 9_500
273280
assert len(VG.nodes) <= 10_500

0 commit comments

Comments
 (0)