24
24
import scipy .sparse as sp
25
25
from collections .abc import Iterable
26
26
27
- from sklearn .metrics import silhouette_score
28
27
from sklearn .preprocessing import LabelEncoder
29
28
from sklearn .metrics .pairwise import pairwise_distances
29
+ from sklearn .metrics import silhouette_score , DistanceMetric
30
30
31
31
from yellowbrick .utils import KneeLocator , get_param_names
32
32
from yellowbrick .style .palettes import LINE_COLOR
44
44
45
45
46
46
# Custom colors; note that LINE_COLOR is imported above
47
- TIMING_COLOR = 'C1' # Color of timing axis, tick, label, and line
48
- METRIC_COLOR = 'C0' # Color of metric axis, tick, label, and line
47
+ TIMING_COLOR = "C1" # Color of timing axis, tick, label, and line
48
+ METRIC_COLOR = "C0" # Color of metric axis, tick, label, and line
49
49
50
50
# Keys for the color dictionary
51
51
CTIMING = "timing"
52
52
CMETRIC = "metric"
53
- CVLINE = "vline"
53
+ CVLINE = "vline"
54
54
55
55
56
56
##########################################################################
@@ -112,7 +112,7 @@ def distortion_score(X, labels, metric="euclidean"):
112
112
113
113
# Compute the square distances from the instances to the center
114
114
distances = pairwise_distances (instances , center , metric = metric )
115
- distances = distances ** 2
115
+ distances = distances ** 2
116
116
117
117
# Add the sum of square distance to the distortion
118
118
distortion += distances .sum ()
@@ -130,9 +130,6 @@ def distortion_score(X, labels, metric="euclidean"):
130
130
"calinski_harabasz" : chs ,
131
131
}
132
132
133
- DISTANCE_METRICS = ['cityblock' , 'cosine' , 'euclidean' , 'haversine' ,
134
- 'l1' , 'l2' , 'manhattan' , 'nan_euclidean' , 'precomputed' ]
135
-
136
133
137
134
class KElbowVisualizer (ClusteringScoreVisualizer ):
138
135
"""
@@ -188,8 +185,8 @@ class KElbowVisualizer(ClusteringScoreVisualizer):
188
185
distance_metric : str or callable, default='euclidean'
189
186
The metric to use when calculating distance between instances in a
190
187
feature array. If metric is a string, it must be one of the options allowed
191
- by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array itself,
192
- use metric="precomputed".
188
+ by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array
189
+ itself, use metric="precomputed".
193
190
194
191
timings : bool, default: True
195
192
Display the fitting time per k to evaluate the amount of time required
@@ -259,7 +256,7 @@ def __init__(
259
256
ax = None ,
260
257
k = 10 ,
261
258
metric = "distortion" ,
262
- distance_metric = ' euclidean' ,
259
+ distance_metric = " euclidean" ,
263
260
timings = True ,
264
261
locate_elbow = True ,
265
262
** kwargs
@@ -273,11 +270,15 @@ def __init__(
273
270
"use one of distortion, silhouette, or calinski_harabasz"
274
271
)
275
272
276
- if distance_metric not in DISTANCE_METRICS :
277
- raise YellowbrickValueError (
278
- "'{} is not a defined distance metric "
279
- "use one of the sklearn metric.pairwise.pairwise_distances"
280
- )
273
+ # Check to ensure the distance metric is valid
274
+ if not callable (distance_metric ):
275
+ try :
276
+ DistanceMetric .get_metric (distance_metric )
277
+ except ValueError as e :
278
+ raise YellowbrickValueError (
279
+ "'{} is not a defined distance metric "
280
+ "use one of the sklearn metric.pairwise.pairwise_distances"
281
+ ) from e
281
282
282
283
# Store the arguments
283
284
self .k = k
@@ -302,7 +303,7 @@ def fit(self, X, y=None, **kwargs):
302
303
``self.elbow_value`` and ``self.elbow_score`` respectively.
303
304
This method finishes up by calling draw to create the plot.
304
305
"""
305
- # Convert K into a tuple argument if an integer
306
+ # Convert K into a tuple argument if an integer
306
307
if isinstance (self .k , int ):
307
308
self .k_values_ = list (range (2 , self .k + 1 ))
308
309
elif (
@@ -340,9 +341,12 @@ def fit(self, X, y=None, **kwargs):
340
341
341
342
# Append the time and score to our plottable metrics
342
343
self .k_timers_ .append (time .time () - start )
343
- if self .metric != 'calinski_harabasz' :
344
- self .k_scores_ .append (self .scoring_metric (X , self .estimator .labels_ ,
345
- metric = self .distance_metric ))
344
+ if self .metric != "calinski_harabasz" :
345
+ self .k_scores_ .append (
346
+ self .scoring_metric (
347
+ X , self .estimator .labels_ , metric = self .distance_metric
348
+ )
349
+ )
346
350
else :
347
351
self .k_scores_ .append (self .scoring_metric (X , self .estimator .labels_ ))
348
352
@@ -437,7 +441,6 @@ def finalize(self):
437
441
self .axes [1 ].set_ylabel ("fit time (seconds)" , color = self .timing_color )
438
442
self .axes [1 ].tick_params ("y" , colors = self .timing_color )
439
443
440
-
441
444
@property
442
445
def metric_color (self ):
443
446
return self .colors [CMETRIC ]
@@ -462,6 +465,7 @@ def vline_color(self):
462
465
def vline_color (self , val ):
463
466
self .colors [CVLINE ] = val
464
467
468
+
465
469
# alias
466
470
KElbow = KElbowVisualizer
467
471
@@ -478,7 +482,7 @@ def kelbow_visualizer(
478
482
ax = None ,
479
483
k = 10 ,
480
484
metric = "distortion" ,
481
- distance_metric = ' euclidean' ,
485
+ distance_metric = " euclidean" ,
482
486
timings = True ,
483
487
locate_elbow = True ,
484
488
show = True ,
@@ -521,8 +525,8 @@ def kelbow_visualizer(
521
525
distance_metric : str or callable, default='euclidean'
522
526
The metric to use when calculating distance between instances in a
523
527
feature array. If metric is a string, it must be one of the options allowed
524
- by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array itself,
525
- use metric="precomputed".
528
+ by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array
529
+ itself, use metric="precomputed".
526
530
527
531
timings : bool, default: True
528
532
Display the fitting time per k to evaluate the amount of time required
@@ -562,7 +566,7 @@ def kelbow_visualizer(
562
566
ax = ax ,
563
567
k = k ,
564
568
metric = metric ,
565
- distance_metric = ' euclidean' ,
569
+ distance_metric = " euclidean" ,
566
570
timings = timings ,
567
571
locate_elbow = locate_elbow ,
568
572
** kwargs
0 commit comments