Skip to content

Commit f7a8e95

Browse files
stergionbbengfortlwgray
authored
SilhouetteVisualizer add support for more estimators (DistrictDataLabs#1294)
Signed-off-by: Benjamin Bengfort <[email protected]> Co-authored-by: Benjamin Bengfort <[email protected]> Co-authored-by: Larry Gray <[email protected]>
1 parent 7a3c94c commit f7a8e95

8 files changed

+113
-30
lines changed
Loading
Loading
Loading
Loading
Loading

tests/test_cluster/test_silhouette.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
import sys
2121
import pytest
2222
import matplotlib.pyplot as plt
23+
import numpy as np
2324

2425
from sklearn.datasets import make_blobs
2526
from sklearn.cluster import KMeans, MiniBatchKMeans
27+
from sklearn.cluster import SpectralClustering, AgglomerativeClustering
2628

2729
from unittest import mock
2830
from tests.base import VisualTestCase
2931

30-
from yellowbrick.datasets import load_nfl
3132
from yellowbrick.cluster.silhouette import SilhouetteVisualizer, silhouette_visualizer
3233

3334

@@ -53,7 +54,6 @@ def test_integrated_kmeans_silhouette(self):
5354
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
5455
)
5556

56-
5757
fig = plt.figure()
5858
ax = fig.add_subplot()
5959

@@ -62,7 +62,6 @@ def test_integrated_kmeans_silhouette(self):
6262
visualizer.finalize()
6363

6464
self.assert_images_similar(visualizer, remove_legend=True)
65-
6665

6766
@pytest.mark.xfail(sys.platform == "win32", reason="images not close on windows")
6867
def test_integrated_mini_batch_kmeans_silhouette(self):
@@ -84,7 +83,6 @@ def test_integrated_mini_batch_kmeans_silhouette(self):
8483
visualizer.finalize()
8584

8685
self.assert_images_similar(visualizer, remove_legend=True)
87-
8886

8987
@pytest.mark.skip(reason="no negative silhouette example available yet")
9088
def test_negative_silhouette_score(self):
@@ -103,7 +101,6 @@ def test_colormap_silhouette(self):
103101
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
104102
)
105103

106-
107104
fig = plt.figure()
108105
ax = fig.add_subplot()
109106

@@ -138,7 +135,7 @@ def test_colors_silhouette(self):
138135
visualizer.finalize()
139136

140137
self.assert_images_similar(visualizer, remove_legend=True)
141-
138+
142139
def test_colormap_as_colors_silhouette(self):
143140
"""
144141
Test no exceptions for modifying the colors in a silhouette visualizer
@@ -162,7 +159,7 @@ def test_colormap_as_colors_silhouette(self):
162159
3.2 if sys.platform == "win32" else 0.01
163160
) # Fails on AppVeyor with RMS 3.143
164161
self.assert_images_similar(visualizer, remove_legend=True, tol=tol)
165-
162+
166163
def test_quick_method(self):
167164
"""
168165
Test the quick method producing a valid visualization
@@ -177,29 +174,44 @@ def test_quick_method(self):
177174

178175
self.assert_images_similar(oz)
179176

180-
@pytest.mark.xfail(
181-
reason="""third test fails with AssertionError: Expected fit
182-
to be called once. Called 0 times."""
183-
)
184177
def test_with_fitted(self):
185178
"""
186179
Test that visualizer properly handles an already-fitted model
187180
"""
188-
X, y = load_nfl(return_dataset=True).to_numpy()
189-
190-
model = MiniBatchKMeans().fit(X, y)
181+
X, y = make_blobs(
182+
n_samples=100, n_features=5, centers=3, shuffle=False, random_state=112
183+
)
184+
model = MiniBatchKMeans().fit(X)
185+
labels = model.predict(X)
191186

192187
with mock.patch.object(model, "fit") as mockfit:
193188
oz = SilhouetteVisualizer(model)
194-
oz.fit(X, y)
189+
oz.fit(X)
195190
mockfit.assert_not_called()
196191

197192
with mock.patch.object(model, "fit") as mockfit:
198193
oz = SilhouetteVisualizer(model, is_fitted=True)
199-
oz.fit(X, y)
194+
oz.fit(X)
200195
mockfit.assert_not_called()
201196

202-
with mock.patch.object(model, "fit") as mockfit:
197+
with mock.patch.object(model, "fit_predict", return_value=labels) as mockfit:
203198
oz = SilhouetteVisualizer(model, is_fitted=False)
204-
oz.fit(X, y)
205-
mockfit.assert_called_once_with(X, y)
199+
oz.fit(X)
200+
mockfit.assert_called_once_with(X, None)
201+
202+
@pytest.mark.parametrize(
    "model",
    [SpectralClustering, AgglomerativeClustering],
)
def test_clusterer_without_predict(self, model):
    """
    Assert that clustering estimators that don't implement
    a predict() method utilize fit_predict()

    Parameters
    ----------
    model : class
        The clustering estimator class under test; it is instantiated
        with ``n_clusters=2`` and wrapped in a SilhouetteVisualizer.
    """
    # Six 2D points laid out as two columns (x=1 and x=4).
    X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])

    try:
        visualizer = SilhouetteVisualizer(model(n_clusters=2))
        visualizer.fit(X)
        visualizer.finalize()
    except AttributeError:
        # Report through pytest directly: this test is parametrized, so it
        # should not depend on a unittest-style ``self.fail`` helper being
        # available on the test class.
        pytest.fail("could not use fit or fit_predict methods")

yellowbrick/cluster/silhouette.py

+82-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,35 @@
2323

2424
from sklearn.metrics import silhouette_score, silhouette_samples
2525

26+
try:
27+
from sklearn.metrics.pairwise import _VALID_METRICS
28+
except ImportError:
29+
_VALID_METRICS = [
30+
"cityblock",
31+
"cosine",
32+
"euclidean",
33+
"l1",
34+
"l2",
35+
"manhattan",
36+
"braycurtis",
37+
"canberra",
38+
"chebyshev",
39+
"correlation",
40+
"dice",
41+
"hamming",
42+
"jaccard",
43+
"kulsinski",
44+
"mahalanobis",
45+
"minkowski",
46+
"rogerstanimoto",
47+
"russellrao",
48+
"seuclidean",
49+
"sokalmichener",
50+
"sokalsneath",
51+
"sqeuclidean",
52+
"yule",
53+
]
54+
2655
from yellowbrick.utils import check_fitted
2756
from yellowbrick.style import resolve_colors
2857
from yellowbrick.cluster.base import ClusteringScoreVisualizer
@@ -113,7 +142,6 @@ class SilhouetteVisualizer(ClusteringScoreVisualizer):
113142
"""
114143

115144
def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):
116-
117145
# Initialize the visualizer bases
118146
super(SilhouetteVisualizer, self).__init__(
119147
estimator, ax=ax, is_fitted=is_fitted, **kwargs
@@ -130,23 +158,47 @@ def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):
130158
def fit(self, X, y=None, **kwargs):
131159
"""
132160
Fits the model and generates the silhouette visualization.
161+
162+
Unlike other visualizers that use the score() method to draw the results, this
163+
visualizer draws during fit since this is when the clusters are
164+
computed. This means that a predict call is required in fit (or a fit_predict)
165+
in order to produce the visualization.
133166
"""
134-
# TODO: decide to use this method or the score method to draw.
135-
# NOTE: Probably this would be better in score, but the standard score
136-
# is a little different and I'm not sure how it's used.
137167

168+
# If the estimator is not fitted, fit it; then call predict to get the labels
169+
# for computing the silhouette score on. If the estimator is already fitted, then
170+
# attempt to predict the labels, but if the estimator is stateless, fit and
171+
# predict on the data specified. At the end of this block, no matter the fitted
172+
# state of the estimator and the method, we should have cluster labels for X.
138173
if not check_fitted(self.estimator, is_fitted_by=self.is_fitted):
139-
# Fit the wrapped estimator
140-
self.estimator.fit(X, y, **kwargs)
174+
if hasattr(self.estimator, "fit_predict"):
175+
labels = self.estimator.fit_predict(X, y, **kwargs)
176+
else:
177+
self.estimator.fit(X, y, **kwargs)
178+
labels = self.estimator.predict(X)
179+
else:
180+
if hasattr(self.estimator, "predict"):
181+
labels = self.estimator.predict(X)
182+
else:
183+
labels = self.estimator.fit_predict(X, y, **kwargs)
141184

142185
# Get the properties of the dataset
143186
self.n_samples_ = X.shape[0]
144-
self.n_clusters_ = self.estimator.n_clusters
187+
188+
# Compute the number of available clusters from the estimator
189+
if hasattr(self.estimator, "n_clusters"):
190+
self.n_clusters_ = self.estimator.n_clusters
191+
else:
192+
unique_labels = set(labels)
193+
n_noise_clusters = 1 if -1 in unique_labels else 0
194+
self.n_clusters_ = len(unique_labels) - n_noise_clusters
195+
196+
# Identify the distance metric to use for silhouette scoring
197+
metric = self._identify_silhouette_metric()
145198

146199
# Compute the scores of the cluster
147-
labels = self.estimator.predict(X)
148-
self.silhouette_score_ = silhouette_score(X, labels)
149-
self.silhouette_samples_ = silhouette_samples(X, labels)
200+
self.silhouette_score_ = silhouette_score(X, labels, metric=metric)
201+
self.silhouette_samples_ = silhouette_samples(X, labels, metric=metric)
150202

151203
# Draw the silhouette figure
152204
self.draw(labels)
@@ -185,7 +237,6 @@ def draw(self, labels):
185237
# For each cluster, plot the silhouette scores
186238
self.y_tick_pos_ = []
187239
for idx in range(self.n_clusters_):
188-
189240
# Collect silhouette scores for samples in the current cluster.
190241
values = self.silhouette_samples_[labels == idx]
191242
values.sort()
@@ -260,6 +311,26 @@ def finalize(self):
260311
# Show legend (Average Silhouette Score axis)
261312
self.ax.legend(loc="best")
262313

314+
def _identify_silhouette_metric(self):
    """
    Choose the distance metric passed to the silhouette scoring functions.

    The wrapped estimator is inspected for a usable metric in this order:

    1. its ``metric`` attribute, if it is a callable or one of the names
       listed in ``_VALID_METRICS``;
    2. its ``affinity`` attribute, if it is one of the names listed in
       ``_VALID_METRICS``.

    Returns
    -------
    metric : str or callable
        The discovered metric, or ``"euclidean"`` when neither attribute
        yields a valid one.
    """
    model = self.estimator

    # An explicit metric wins; callables are passed through untouched.
    if hasattr(model, "metric"):
        candidate = model.metric
        if callable(candidate) or candidate in _VALID_METRICS:
            return candidate

    # Otherwise fall back to the affinity, but only when it is a known
    # metric name (a callable affinity is deliberately not returned).
    if hasattr(model, "affinity"):
        candidate = model.affinity
        if candidate in _VALID_METRICS:
            return candidate

    return "euclidean"
333+
263334

264335
##########################################################################
265336
## Quick Method

0 commit comments

Comments
 (0)