Skip to content

Commit 7bc6df0

Browse files
authored
Merge pull request #27 from lias-laboratory/26-rebase-the-test-suit-on-scikit-learn-guidelines
26 rebase the test suit on scikit learn guidelines. Updated test suite to match sklearn compliance guidelines. Update the minor version to reflect the change. New version is now 1.4.1
2 parents 77bd09b + 7e3cb1d commit 7bc6df0

File tree

4 files changed

+24
-21
lines changed

4 files changed

+24
-21
lines changed

.coverage

0 Bytes
Binary file not shown.

examples/plot_benchmark_custom.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,18 @@ def benchmark_radius_clustering():
184184
fig.suptitle("Benchmark of Radius Clustering Solvers", fontsize=16)
185185

186186
axs['time'].set_yscale('log') # Use logarithmic scale for better visibility
187-
for algo, algo_results in results.items():
188-
# Plot execution time
189-
axs['time'].plot(
190-
DATASETS.keys(),
191-
algo_results["time"],
192-
marker='o',
193-
label=algo,
194-
)
195-
# Plot number of clusters
187+
188+
algorithms = list(results.keys())
189+
dataset_names = list(DATASETS.keys())
190+
n_algos = len(algorithms)
191+
x_indices = np.arange(len(dataset_names)) # the label locations
192+
bar_width = 0.8 / n_algos # the width of the bars, with some padding
193+
194+
for i, algo in enumerate(algorithms):
195+
times = results[algo]["time"]
196+
# Calculate position for each bar in the group to center them
197+
position = x_indices - (n_algos * bar_width / 2) + (i * bar_width) + bar_width / 2
198+
axs['time'].bar(position, times, bar_width, label=algo)
196199

197200
for i, (name, (dataset, _)) in enumerate(DATASETS.items()):
198201
axs[name].bar(
@@ -207,14 +210,15 @@ def benchmark_radius_clustering():
207210
linestyle='--',
208211
)
209212
axs[name].set_title(name)
210-
axs[name].set_xlabel("Algorithms")
211213

212214
axs["iris"].set_ylabel("Number of clusters")
213215
axs["glass"].set_ylabel("Number of clusters")
214216

215217
axs['time'].set_title("Execution Time (log scale)")
216218
axs['time'].set_xlabel("Datasets")
217219
axs['time'].set_ylabel("Time (seconds)")
220+
axs['time'].set_xticks(x_indices)
221+
axs['time'].set_xticklabels(dataset_names)
218222
axs['time'].legend(title="Algorithms")
219223
plt.tight_layout()
220224
plt.show()

src/radius_clustering/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from .radius_clustering import RadiusClustering
33

44
__all__ = ["RadiusClustering"]
5-
__version__ = "1.4.0"
5+
__version__ = "1.4.1"

tests/test_structural.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
from logging import getLogger
2-
3-
logger = getLogger(__name__)
4-
logger.setLevel("INFO")
5-
1+
from sklearn.utils.estimator_checks import parametrize_with_checks
62
def test_import():
73
import radius_clustering as rad
84

95

106
def test_from_import():
117
from radius_clustering import RadiusClustering
128

13-
def test_check_estimator_api_consistency():
14-
from radius_clustering import RadiusClustering
15-
from sklearn.utils.estimator_checks import check_estimator
169

17-
# Check the API consistency of the RadiusClustering estimator
18-
check_estimator(RadiusClustering())
10+
from radius_clustering import RadiusClustering
11+
12+
@parametrize_with_checks([RadiusClustering()])
13+
def test_check_estimator_api_consistency(estimator, check, request):
14+
15+
"""Check the API consistency of the RadiusClustering estimator
16+
"""
17+
check(estimator)

0 commit comments

Comments
 (0)