23
23
24
24
from sklearn .metrics import silhouette_score , silhouette_samples
25
25
26
+ try :
27
+ from sklearn .metrics .pairwise import _VALID_METRICS
28
+ except ImportError :
29
+ _VALID_METRICS = [
30
+ "cityblock" ,
31
+ "cosine" ,
32
+ "euclidean" ,
33
+ "l1" ,
34
+ "l2" ,
35
+ "manhattan" ,
36
+ "braycurtis" ,
37
+ "canberra" ,
38
+ "chebyshev" ,
39
+ "correlation" ,
40
+ "dice" ,
41
+ "hamming" ,
42
+ "jaccard" ,
43
+ "kulsinski" ,
44
+ "mahalanobis" ,
45
+ "minkowski" ,
46
+ "rogerstanimoto" ,
47
+ "russellrao" ,
48
+ "seuclidean" ,
49
+ "sokalmichener" ,
50
+ "sokalsneath" ,
51
+ "sqeuclidean" ,
52
+ "yule" ,
53
+ ]
54
+
26
55
from yellowbrick .utils import check_fitted
27
56
from yellowbrick .style import resolve_colors
28
57
from yellowbrick .cluster .base import ClusteringScoreVisualizer
@@ -113,7 +142,6 @@ class SilhouetteVisualizer(ClusteringScoreVisualizer):
113
142
"""
114
143
115
144
def __init__ (self , estimator , ax = None , colors = None , is_fitted = "auto" , ** kwargs ):
116
-
117
145
# Initialize the visualizer bases
118
146
super (SilhouetteVisualizer , self ).__init__ (
119
147
estimator , ax = ax , is_fitted = is_fitted , ** kwargs
@@ -130,23 +158,47 @@ def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):
130
158
def fit (self , X , y = None , ** kwargs ):
131
159
"""
132
160
Fits the model and generates the silhouette visualization.
161
+
162
+ Unlike other visualizers that use the score() method to draw the results, this
163
+ visualizer errs on visualizing on fit since this is when the clusters are
164
+ computed. This means that a predict call is required in fit (or a fit_predict)
165
+ in order to produce the visualization.
133
166
"""
134
- # TODO: decide to use this method or the score method to draw.
135
- # NOTE: Probably this would be better in score, but the standard score
136
- # is a little different and I'm not sure how it's used.
137
167
168
+ # If the estimator is not fitted, fit it; then call predict to get the labels
169
+ # for computing the silhoutte score on. If the estimator is already fitted, then
170
+ # attempt to predict the labels, but if the estimator is stateless, fit and
171
+ # predict on the data specified. At the end of this block, no matter the fitted
172
+ # state of the estimator and the method, we should have cluster labels for X.
138
173
if not check_fitted (self .estimator , is_fitted_by = self .is_fitted ):
139
- # Fit the wrapped estimator
140
- self .estimator .fit (X , y , ** kwargs )
174
+ if hasattr (self .estimator , "fit_predict" ):
175
+ labels = self .estimator .fit_predict (X , y , ** kwargs )
176
+ else :
177
+ self .estimator .fit (X , y , ** kwargs )
178
+ labels = self .estimator .predict (X )
179
+ else :
180
+ if hasattr (self .estimator , "predict" ):
181
+ labels = self .estimator .predict (X )
182
+ else :
183
+ labels = self .estimator .fit_predict (X , y , ** kwargs )
141
184
142
185
# Get the properties of the dataset
143
186
self .n_samples_ = X .shape [0 ]
144
- self .n_clusters_ = self .estimator .n_clusters
187
+
188
+ # Compute the number of available clusters from the estimator
189
+ if hasattr (self .estimator , "n_clusters" ):
190
+ self .n_clusters_ = self .estimator .n_clusters
191
+ else :
192
+ unique_labels = set (labels )
193
+ n_noise_clusters = 1 if - 1 in unique_labels else 0
194
+ self .n_clusters_ = len (unique_labels ) - n_noise_clusters
195
+
196
+ # Identify the distance metric to use for silhouette scoring
197
+ metric = self ._identify_silhouette_metric ()
145
198
146
199
# Compute the scores of the cluster
147
- labels = self .estimator .predict (X )
148
- self .silhouette_score_ = silhouette_score (X , labels )
149
- self .silhouette_samples_ = silhouette_samples (X , labels )
200
+ self .silhouette_score_ = silhouette_score (X , labels , metric = metric )
201
+ self .silhouette_samples_ = silhouette_samples (X , labels , metric = metric )
150
202
151
203
# Draw the silhouette figure
152
204
self .draw (labels )
@@ -185,7 +237,6 @@ def draw(self, labels):
185
237
# For each cluster, plot the silhouette scores
186
238
self .y_tick_pos_ = []
187
239
for idx in range (self .n_clusters_ ):
188
-
189
240
# Collect silhouette scores for samples in the current cluster .
190
241
values = self .silhouette_samples_ [labels == idx ]
191
242
values .sort ()
@@ -260,6 +311,26 @@ def finalize(self):
260
311
# Show legend (Average Silhouette Score axis)
261
312
self .ax .legend (loc = "best" )
262
313
314
+ def _identify_silhouette_metric (self ):
315
+ """
316
+ The Silhouette metric must be one of the distance options allowed by
317
+ metrics.pairwise.pairwise_distances or a callable. This method attempts to
318
+ discover a valid distance metric from the underlying estimator or returns
319
+ "euclidean" by default.
320
+ """
321
+ if hasattr (self .estimator , "metric" ):
322
+ if callable (self .estimator .metric ):
323
+ return self .estimator .metric
324
+
325
+ if self .estimator .metric in _VALID_METRICS :
326
+ return self .estimator .metric
327
+
328
+ if hasattr (self .estimator , "affinity" ):
329
+ if self .estimator .affinity in _VALID_METRICS :
330
+ return self .estimator .affinity
331
+
332
+ return "euclidean"
333
+
263
334
264
335
##########################################################################
265
336
## Quick Method
0 commit comments