Skip to content

Commit da26d1e

Browse files
authored
Generic handling of sklearn distance metric (DistrictDataLabs#1300)
1 parent 5f12bc3 commit da26d1e

File tree

5 files changed

+33
-29
lines changed

5 files changed

+33
-29
lines changed

docs/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Library Dependencies
22
matplotlib>=3.3
33
scipy>=1.6.0
4-
scikit-learn>=1.0.0
4+
scikit-learn>=1.0.2
55
numpy>=1.16.0
66
cycler>=0.10.0
77
umap-learn>=0.5.1

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Dependencies
22
matplotlib>=2.0.2,!=3.0.0
33
scipy>=1.0.0
4-
scikit-learn>=1.0.0
4+
scikit-learn>=1.0.2
55
numpy>=1.16.0
66
cycler>=0.10.0
77

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ filterwarnings =
3232
[flake8]
3333
# match black maximum line length
3434
max-line-length = 88
35-
extend-ignore = E203
35+
extend-ignore = E203,E266
3636
per-file-ignores =
3737
__init__.py:F401
3838
test_*.py:F405,F403

tests/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Library Dependencies
77
matplotlib==3.4.2
88
scipy==1.8.0
9-
scikit-learn==1.0.0
9+
scikit-learn==1.0.2
1010
numpy==1.22.0
1111
cycler==0.10.0
1212

yellowbrick/cluster/elbow.py

+29-25
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
import scipy.sparse as sp
2525
from collections.abc import Iterable
2626

27-
from sklearn.metrics import silhouette_score
2827
from sklearn.preprocessing import LabelEncoder
2928
from sklearn.metrics.pairwise import pairwise_distances
29+
from sklearn.metrics import silhouette_score, DistanceMetric
3030

3131
from yellowbrick.utils import KneeLocator, get_param_names
3232
from yellowbrick.style.palettes import LINE_COLOR
@@ -44,13 +44,13 @@
4444

4545

4646
# Custom colors; note that LINE_COLOR is imported above
47-
TIMING_COLOR = 'C1' # Color of timing axis, tick, label, and line
48-
METRIC_COLOR = 'C0' # Color of metric axis, tick, label, and line
47+
TIMING_COLOR = "C1" # Color of timing axis, tick, label, and line
48+
METRIC_COLOR = "C0" # Color of metric axis, tick, label, and line
4949

5050
# Keys for the color dictionary
5151
CTIMING = "timing"
5252
CMETRIC = "metric"
53-
CVLINE = "vline"
53+
CVLINE = "vline"
5454

5555

5656
##########################################################################
@@ -112,7 +112,7 @@ def distortion_score(X, labels, metric="euclidean"):
112112

113113
# Compute the square distances from the instances to the center
114114
distances = pairwise_distances(instances, center, metric=metric)
115-
distances = distances ** 2
115+
distances = distances**2
116116

117117
# Add the sum of square distance to the distortion
118118
distortion += distances.sum()
@@ -130,9 +130,6 @@ def distortion_score(X, labels, metric="euclidean"):
130130
"calinski_harabasz": chs,
131131
}
132132

133-
DISTANCE_METRICS = ['cityblock', 'cosine', 'euclidean', 'haversine',
134-
'l1', 'l2', 'manhattan', 'nan_euclidean', 'precomputed']
135-
136133

137134
class KElbowVisualizer(ClusteringScoreVisualizer):
138135
"""
@@ -188,8 +185,8 @@ class KElbowVisualizer(ClusteringScoreVisualizer):
188185
distance_metric : str or callable, default='euclidean'
189186
The metric to use when calculating distance between instances in a
190187
feature array. If metric is a string, it must be one of the options allowed
191-
by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array itself,
192-
use metric="precomputed".
188+
by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array
189+
itself, use metric="precomputed".
193190
194191
timings : bool, default: True
195192
Display the fitting time per k to evaluate the amount of time required
@@ -259,7 +256,7 @@ def __init__(
259256
ax=None,
260257
k=10,
261258
metric="distortion",
262-
distance_metric='euclidean',
259+
distance_metric="euclidean",
263260
timings=True,
264261
locate_elbow=True,
265262
**kwargs
@@ -273,11 +270,15 @@ def __init__(
273270
"use one of distortion, silhouette, or calinski_harabasz"
274271
)
275272

276-
if distance_metric not in DISTANCE_METRICS:
277-
raise YellowbrickValueError(
278-
"'{} is not a defined distance metric "
279-
"use one of the sklearn metric.pairwise.pairwise_distances"
280-
)
273+
# Check to ensure the distance metric is valid
274+
if not callable(distance_metric):
275+
try:
276+
DistanceMetric.get_metric(distance_metric)
277+
except ValueError as e:
278+
raise YellowbrickValueError(
279+
"'{} is not a defined distance metric "
280+
"use one of the sklearn metric.pairwise.pairwise_distances"
281+
) from e
281282

282283
# Store the arguments
283284
self.k = k
@@ -302,7 +303,7 @@ def fit(self, X, y=None, **kwargs):
302303
``self.elbow_value`` and ``self.elbow_score`` respectively.
303304
This method finishes up by calling draw to create the plot.
304305
"""
305-
# Convert K into a tuple argument if an integer
306+
# Convert K into a tuple argument if an integer
306307
if isinstance(self.k, int):
307308
self.k_values_ = list(range(2, self.k + 1))
308309
elif (
@@ -340,9 +341,12 @@ def fit(self, X, y=None, **kwargs):
340341

341342
# Append the time and score to our plottable metrics
342343
self.k_timers_.append(time.time() - start)
343-
if self.metric != 'calinski_harabasz':
344-
self.k_scores_.append(self.scoring_metric(X, self.estimator.labels_,
345-
metric=self.distance_metric))
344+
if self.metric != "calinski_harabasz":
345+
self.k_scores_.append(
346+
self.scoring_metric(
347+
X, self.estimator.labels_, metric=self.distance_metric
348+
)
349+
)
346350
else:
347351
self.k_scores_.append(self.scoring_metric(X, self.estimator.labels_))
348352

@@ -437,7 +441,6 @@ def finalize(self):
437441
self.axes[1].set_ylabel("fit time (seconds)", color=self.timing_color)
438442
self.axes[1].tick_params("y", colors=self.timing_color)
439443

440-
441444
@property
442445
def metric_color(self):
443446
return self.colors[CMETRIC]
@@ -462,6 +465,7 @@ def vline_color(self):
462465
def vline_color(self, val):
463466
self.colors[CVLINE] = val
464467

468+
465469
# alias
466470
KElbow = KElbowVisualizer
467471

@@ -478,7 +482,7 @@ def kelbow_visualizer(
478482
ax=None,
479483
k=10,
480484
metric="distortion",
481-
distance_metric='euclidean',
485+
distance_metric="euclidean",
482486
timings=True,
483487
locate_elbow=True,
484488
show=True,
@@ -521,8 +525,8 @@ def kelbow_visualizer(
521525
distance_metric : str or callable, default='euclidean'
522526
The metric to use when calculating distance between instances in a
523527
feature array. If metric is a string, it must be one of the options allowed
524-
by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array itself,
525-
use metric="precomputed".
528+
by sklearn's metrics.pairwise.pairwise_distances. If X is the distance array
529+
itself, use metric="precomputed".
526530
527531
timings : bool, default: True
528532
Display the fitting time per k to evaluate the amount of time required
@@ -562,7 +566,7 @@ def kelbow_visualizer(
562566
ax=ax,
563567
k=k,
564568
metric=metric,
565-
distance_metric='euclidean',
569+
distance_metric="euclidean",
566570
timings=timings,
567571
locate_elbow=locate_elbow,
568572
**kwargs

0 commit comments

Comments
 (0)