diff --git a/doc/modules/qualitymetrics/hit_miss_rate.svg b/doc/modules/qualitymetrics/hit_miss_rate.svg
new file mode 100644
index 0000000000..8526486dbc
--- /dev/null
+++ b/doc/modules/qualitymetrics/hit_miss_rate.svg
@@ -0,0 +1,380 @@
+
+
+
+
diff --git a/doc/modules/qualitymetrics/nearest_neighbor.rst b/doc/modules/qualitymetrics/nearest_neighbor.rst
index bbd8f6628a..8e9470a9ca 100644
--- a/doc/modules/qualitymetrics/nearest_neighbor.rst
+++ b/doc/modules/qualitymetrics/nearest_neighbor.rst
@@ -17,52 +17,66 @@ All options involve non-parametric calculations in PCA space.
:code:`nearest_neighbor`
------------------------
-The membership function, :math:`\rho` is defined such that for any spike :math:`g_i`` in some cluster :math:`G`, :math:`\rho(g_i) = G`.
-Additionally, the nearest neighbor function :math:`n_k(g_i)` is defined such that the output of the function is the set of :math:`k` spikes which are closest to :math:`g_i`.
+Consider the set of all spikes :math:`S` and a unit :math:`A,` (where :math:`A` is a subset of :math:`S`) with
+spikes :math:`a_i \in A`. Each spike has a principal component project, which is a vector.
+We can use these vectors to compute the Euclidean distance on :math:`S`. The :math:`k` nearest
+neighbors to a spike :math:`a_i` are the :math:`k` spikes closest to it with respect to this distance.
-For a unit associated with cluster :math:`C`, a subset of spikes are randomly drawn to form the cluster :math:`A`.
-A subset of spikes which are not in :math:`C` are drawn to form the cluster :math:`B`.
-Note that :math:`|A| = |B|`.
-The NN-hit rate for :math:`C` is then:
+We can define a hit rate and miss rate of a spike based on the notion of nearest neighbors. The hit rate tells us
+how many of :math:`a_i` s nearby neighbors are also in :math:`A`, rather than :math:`\neg A`. It is defined as
.. math::
- NN_{\textrm{hit}}(C) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in A : \rho(n_i(x)) = A \} |}{ | A | }
+ NN^{\textrm{hit}}_k(a_i) = \frac{1}{k}(\text{number of } k^{\text{th} } \text{ nearest neighbors of } a_i \text{ in } A )
+Then the hit rate of a unit is then
-Similarly, the NN-miss rate for :math:`C` is:
+.. math::
+ NN^{\textrm{hit}}_k(A) = \sum_{a_i \in A} NN^{\textrm{hit}}_k(a_i)
+
+To compute the miss rate of :math:`A`, consider the spikes :math:`b_j` which are *not* in :math:`A`. The miss rate
+measure the number of these whose nearby neighbors *are* in :math:`A`. It is defined as
.. math::
- NN_{\textrm{miss}}(C) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in B : \rho(n_i(x)) = A \} |}{ | B | }
+ NN^{\textrm{miss}}_k(b_j,A) = \frac{1}{k}(\text{number of } k^{\text{th} } \text{ nearest neighbors of } b_j \text{ in } A )
-NN-hit rate gives an estimate of contamination (an uncontaminated unit should have a high NN-hit rate).
-NN-miss rate gives an estimate of completeness.
-A more complete unit should have a low NN-miss rate.
+This can be written in terms of all the other units :math:`B, C, ...` as
-:code:`nn_isolation`
---------------------
+.. math::
+ &NN^{\textrm{miss}}_k(A) = \sum_{b \in \neg A} NN^{\textrm{miss}}_k(b,A) \\
+ &= \sum_{b_j \in B} NN^{\textrm{miss}}_k(b_j, A) + \sum_{c_l \in C} NN^{\textrm{miss}}_k(c_l, A) + \ldots
-The overall logic of this approach is to choose a cluster for which the isolation is to be computed, and compute the pairwise isolation score between the chosen cluster and every other cluster.
-The isolation score is then the minimum of the pairwise scores (the worst case).
+A visualisation of the hit and miss rates for individual spikes is shown below
-Let A and B be two clusters from sorting.
-We set :math:`|A| = |B|` by subsampling as appropriate to match the size of the smaller cluster (or the :code:`max_spikes_for_nn` parameter value, if using).
-We also restrict the waveforms to channels with significant signal.
+.. image:: hit_miss_rate.svg
+ :width: 600
+ :align: center
-The pairwise isolation between clusters A and B is then:
+
+:code:`nn_isolation`
+--------------------
+
+The pairwise isolation score between two units :math:`A` and :math:`B` is a measurement of how well
+separated they are. It can be written in terms of the hit rates, although here we only consider
+:math:`A\cup B` rather than the space of all spikes. It is then defined as
.. math::
+ NN^\textrm{isolation}_k(A,B) = \sum_{x \in A\cup B} NN^{\textrm{hit}}_k(x)
- NN_{\textrm{isolation}}(A, B) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in A \cup B : \rho(n_i(x)) = \rho(x) \} |}{ | A \cup B | }
+The isolation of a unit :math:`A` is then given by the worst (minimum) pairwise isolation score:
+.. math::
+ NN^\textrm{isolation}_k(A) = \min_{B \in \text{all other units}} NN^\textrm{isolation}_k(A,B)
-Note that nn_isolation is affected by the size of the clusters, so setting the :code:`max_spikes_for_nn` may aid downstream comparison of scores.
+If the number of spikes in each unit are uneven, the metric can give misleading results. Hence we randomly
+sample the spikes in each unit so that :math:`|A| = |B| = \text{min}(|A|,|B|, \text{max_spikes})`.
:code:`nn_noise_overlap`
------------------------
-A noise cluster is generated by randomly sampling voltage snippets from the recording.
-Following a similar procedure to that of the nn_isolation method, compute isolation between the cluster of interest and the generated noise cluster.
-noise overlap is then :math:`1 - NN_{\textrm{isolation}}`.
+A noise cluster :math:`C` is generated by randomly sampling voltage snippets from the recording.
+Following a similar procedure to that of the nn_isolation method, compute isolation between the
+cluster of interest :math:`A` and the generated noise cluster.
+noise overlap is then :math:`1 - NN^{\textrm{isolation}}_k(A,C)`.
This metric gives an indication of the contamination present in the unit cluster.
diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 1c5a491bf8..8ff5d7b113 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -33,9 +33,7 @@
max_spikes=10000,
n_neighbors=5,
),
- nn_isolation=dict(
- max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg"
- ),
+ nn_isolation=dict(max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=5),
nn_noise_overlap=dict(
max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg"
),
@@ -107,15 +105,16 @@ def compute_pc_metrics(
pc_metrics["nn_miss_rate"] = {}
if "nn_isolation" in metric_names:
+ pc_metrics["nn_isolation"] = {}
pc_metrics["nn_unit_id"] = {}
- possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"]
+ possible_nn_metrics = ["nn_noise_overlap"]
nn_metrics = list(set(metric_names).intersection(possible_nn_metrics))
non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics))
# Compute nspikes and firing rate outside of main loop for speed
- if nn_metrics:
+ if nn_metrics or "nn_isolation" in metric_names:
n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids)
fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids)
else:
@@ -196,16 +195,10 @@ def compute_pc_metrics(
**metric_params,
)
except:
- if metric_name == "nn_isolation":
- res = (np.nan, np.nan)
- elif metric_name == "nn_noise_overlap":
+ if metric_name == "nn_noise_overlap":
res = np.nan
- if metric_name == "nn_isolation":
- nn_isolation, nn_unit_id = res
- pc_metrics["nn_isolation"][unit_id] = nn_isolation
- pc_metrics["nn_unit_id"][unit_id] = nn_unit_id
- elif metric_name == "nn_noise_overlap":
+ if metric_name == "nn_noise_overlap":
pc_metrics["nn_noise_overlap"][unit_id] = res
return pc_metrics
@@ -410,18 +403,16 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n
def nearest_neighbors_isolation(
- sorting_analyzer,
+ all_units_ids,
+ pcs_flat,
+ labels,
this_unit_id: int | str,
- n_spikes_all_units: dict = None,
- fr_all_units: dict = None,
+ n_spikes_all_units: dict,
+ fr_all_units: dict,
max_spikes: int = 1000,
min_spikes: int = 10,
min_fr: float = 0.0,
n_neighbors: int = 5,
- n_components: int = 10,
- radius_um: float = 100,
- peak_sign: str = "neg",
- min_spatial_overlap: float = 0.5,
seed=None,
):
"""
@@ -429,16 +420,18 @@ def nearest_neighbors_isolation(
Parameters
----------
- sorting_analyzer : SortingAnalyzer
- A SortingAnalyzer object.
+ all_units_ids : np.array
+ List of unit_ids for the sorting
+ pcs_flat : 2d array
+ The PCs for all spikes, organized as [num_spikes, PCs].
+ labels : 1d array
+ The unit labels for all spikes. Must have length of number of spikes.
this_unit_id : int | str
The ID for the unit to calculate these metrics for.
- n_spikes_all_units : dict, default: None
- Dictionary of the form ``{: }`` for the waveform extractor.
- Recomputed if None.
- fr_all_units : dict, default: None
- Dictionary of the form ``{: }`` for the waveform extractor.
- Recomputed if None.
+ n_spikes_all_units : dict
+ Dictionary of the form ``{: }`` for the sorting analyzer.
+ fr_all_units : dict
+ Dictionary of the form ``{: }`` for the sorting analyzer.
max_spikes : int, default: 1000
Max number of spikes to use per unit.
min_spikes : int, default: 10
@@ -451,16 +444,6 @@ def nearest_neighbors_isolation(
and are ignored when selecting other units' neighbors.
n_neighbors : int, default: 5
Number of neighbors to check membership of.
- n_components : int, default: 10
- The number of PC components to use to project the snippets to.
- radius_um : float, default: 100
- The radius, in um, that channels need to be within the peak channel to be included.
- peak_sign : "neg" | "pos" | "both", default: "neg"
- The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer
- is not sparse already.
- min_spatial_overlap : float, default: 100
- In case sorting_analyzer is sparse, other units are selected if they share at least
- `min_spatial_overlap` times `n_target_unit_channels` with the target unit.
seed : int, default: None
Seed for random subsampling of spikes.
@@ -469,7 +452,7 @@ def nearest_neighbors_isolation(
nn_isolation : float
The calculation nearest neighbor isolation metric for `this_unit_id`.
If the unit has fewer than `min_spikes`, returns numpy.NaN instead.
- nn_unit_id : np.int16
+ nn_unit_id : float
Id of the "nearest neighbor" unit (unit with lowest isolation score from `this_unit_id`).
Notes
@@ -480,41 +463,15 @@ def nearest_neighbors_isolation(
#. Compute the isolation score with every other cluster
#. Isolation score is defined as the min of 2. (i.e. 'worst-case measure')
- The implementation of this approach is:
-
- Let A and B be two clusters from sorting.
-
- We set \\|A\\| = \\|B\\|:
-
- * | If max_spikes < \\|A\\| and max_spikes < \\|B\\|:
- | Then randomly subsample max_spikes samples from A and B.
- * | If max_spikes > min(\\|A\\|, \\|B\\|) (e.g. \\|A\\| > max_spikes > \\|B\\|):
- | Then randomly subsample min(\\|A\\|, \\|B\\|) samples from A and B.
+ We set \\|A\\| = \\|B\\| = min(\\|A\\|, \\|B\\|, max_spikes).
This is because the metric is affected by the size of the clusters being compared
independently of how well-isolated they are.
- We also restrict the waveforms to channels with significant signal.
-
- See docstring for `_compute_isolation` for the definition of isolation score.
-
References
----------
Based on isolation metric described in [Chung]_
"""
- from sklearn.decomposition import IncrementalPCA
-
- rng = np.random.default_rng(seed=seed)
-
- waveforms_ext = sorting_analyzer.get_extension("waveforms")
- assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'"
-
- sorting = sorting_analyzer.sorting
- all_units_ids = sorting.get_unit_ids()
- if n_spikes_all_units is None:
- n_spikes_all_units = compute_num_spikes(sorting_analyzer)
- if fr_all_units is None:
- fr_all_units = compute_firing_rates(sorting_analyzer)
# if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
if n_spikes_all_units[this_unit_id] < min_spikes:
@@ -530,100 +487,66 @@ def nearest_neighbors_isolation(
)
return np.nan, np.nan
else:
- # first remove the units with too few spikes
- unit_ids_to_keep = np.array(
- [
- unit
- for unit in all_units_ids
- if (n_spikes_all_units[unit] >= min_spikes and fr_all_units[unit] >= min_fr)
- ]
- )
- sorting = sorting.select_units(unit_ids=unit_ids_to_keep)
+ other_unit_ids = np.delete(all_units_ids, np.argwhere(all_units_ids == this_unit_id)[0][0])
+ all_scores = np.zeros(len(other_unit_ids))
+ for other_unit_ind, other_unit_id in enumerate(other_unit_ids):
+ all_scores[other_unit_ind] = isolation_score_two_clusters(
+ pcs_flat, labels, this_unit_id, other_unit_id, n_neighbors, seed, max_spikes
+ )
+ return np.min(all_scores), other_unit_ids[np.argmin(all_scores)]
- all_units_ids = sorting.get_unit_ids()
- other_units_ids = np.setdiff1d(all_units_ids, this_unit_id)
- # get waveforms of target unit
- # waveforms_target_unit = sorting_analyzer.get_waveforms(unit_id=this_unit_id)
- waveforms_target_unit = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False)
+def isolation_score_two_clusters(pcs, labels, unit_id, other_unit_id, n_neighbors, seed, max_spikes):
+ """
+ Calculate pairwise isolation score of two units.
- n_spikes_target_unit = waveforms_target_unit.shape[0]
+ Parameters
+ ----------
+ pcs : 2d array
+ The PCs for all spikes, organized as [num_spikes, PCs].
+ labels : 1d array
+ The cluster labels for all spikes. Must have length of number of spikes.
+ unit_id : int | str
+ The ID for the unit being compared.
+ other_unit_id : int | str
+ The ID for the other unit which is being compared against.
+ max_spikes : int, default: 1000
+ Max number of spikes to use per unit.
+ n_neighbors : int, default: 5
+ Number of neighbors to check membership of.
+ seed : int, default: None
+ Seed for random subsampling of spikes.
- # find units whose signal channels (i.e. channels inside some radius around
- # the channel with largest amplitude) overlap with signal channels of the target unit
- if sorting_analyzer.is_sparse():
- sparsity = sorting_analyzer.sparsity
- else:
- sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um)
- closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id]
- n_channels_target_unit = len(closest_chans_target_unit)
- # select other units that have a minimum spatial overlap with target unit
- other_units_ids = [
- unit_id
- for unit_id in other_units_ids
- if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit))
- >= (n_channels_target_unit * min_spatial_overlap)
- ]
-
- # if no unit is within neighborhood of target unit, then just say isolation is 1 (best possible)
- if not other_units_ids:
- nn_isolation = 1
- nn_unit_id = np.nan
- # if there are units to compare, then compute isolation with each
- else:
- isolation = np.zeros(
- len(other_units_ids),
- )
- for other_unit_id in other_units_ids:
- # waveforms_other_unit = sorting_analyzer.get_waveforms(unit_id=other_unit_id)
- waveforms_other_unit = waveforms_ext.get_waveforms_one_unit(unit_id=other_unit_id, force_dense=False)
-
- n_spikes_other_unit = waveforms_other_unit.shape[0]
- closest_chans_other_unit = sparsity.unit_id_to_channel_indices[other_unit_id]
- n_snippets = np.min([n_spikes_target_unit, n_spikes_other_unit, max_spikes])
-
- # make the two clusters equal in terms of: number of spikes & channels with signal
- waveforms_target_unit_idx = rng.choice(n_spikes_target_unit, size=n_snippets, replace=False)
- waveforms_target_unit_sampled = waveforms_target_unit[waveforms_target_unit_idx]
- waveforms_other_unit_idx = rng.choice(n_spikes_other_unit, size=n_snippets, replace=False)
- waveforms_other_unit_sampled = waveforms_other_unit[waveforms_other_unit_idx]
-
- # project this unit and other unit waveforms on common subspace
- common_channel_idxs = np.intersect1d(closest_chans_target_unit, closest_chans_other_unit)
- if sorting_analyzer.is_sparse():
- # in this case, waveforms are sparse so we need to do some smart indexing
- waveforms_target_unit_sampled = waveforms_target_unit_sampled[
- :, :, np.isin(closest_chans_target_unit, common_channel_idxs)
- ]
- waveforms_other_unit_sampled = waveforms_other_unit_sampled[
- :, :, np.isin(closest_chans_other_unit, common_channel_idxs)
- ]
- else:
- waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs]
- waveforms_other_unit_sampled = waveforms_other_unit_sampled[:, :, common_channel_idxs]
-
- # compute principal components after concatenation
- all_snippets = np.concatenate(
- [
- waveforms_target_unit_sampled.reshape((n_snippets, -1)),
- waveforms_other_unit_sampled.reshape((n_snippets, -1)),
- ],
- axis=0,
- )
- pca = IncrementalPCA(n_components=n_components)
- pca.partial_fit(all_snippets)
- projected_snippets = pca.transform(all_snippets)
+ Returns
+ -------
+ isolation_score : float
+ Pairwise isolation score between the two units.
+ """
- # compute isolation
- isolation[other_unit_id == other_units_ids] = _compute_isolation(
- projected_snippets[:n_snippets, :], projected_snippets[n_snippets:, :], n_neighbors
- )
- # isolation metric is the minimum of the pairwise isolations
- # nn_unit_id is the unit with lowest isolation score
- nn_isolation = np.min(isolation)
- nn_unit_id = other_units_ids[np.argmin(isolation)]
+ rng = np.random.default_rng(seed=seed)
+ from sklearn.neighbors import NearestNeighbors
- return nn_isolation, nn_unit_id
+ pcs_A = pcs[labels == unit_id]
+ pcs_B = pcs[labels == other_unit_id]
+ n_A = len(pcs_A)
+ n_B = len(pcs_B)
+
+ # ensure sample from clusters is the same size
+ if n_A < n_B or n_A > max_spikes:
+ pcs_B = pcs_B[rng.choice(n_B, size=min(n_A, max_spikes), replace=False)]
+ else:
+ pcs_A = pcs_A[rng.choice(n_A, size=min(n_B, max_spikes), replace=False)]
+
+ X = np.concatenate((pcs_A, pcs_B), 0)
+
+ nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm="ball_tree").fit(X)
+ _, indices = nbrs.kneighbors(X)
+
+ A_isolated = np.sum(indices[:n_A, 1:] < n_A)
+ B_isolated = np.sum(indices[n_B:, 1:] >= n_B)
+
+ isolation_score = (A_isolated + B_isolated) / ((n_neighbors) * (n_A + n_B))
+ return isolation_score
def nearest_neighbors_noise_overlap(
@@ -1029,6 +952,20 @@ def pca_metrics_one_unit(args):
pc_metrics["nn_hit_rate"] = nn_hit_rate
pc_metrics["nn_miss_rate"] = nn_miss_rate
+ if "nn_isolation" in metric_names:
+ isolation, closest_unit = nearest_neighbors_isolation(
+ all_units_ids=unit_ids,
+ pcs_flat=pcs_flat,
+ labels=labels,
+ this_unit_id=unit_id,
+ fr_all_units=fr_all_units,
+ n_spikes_all_units=n_spikes_all_units,
+ seed=seed,
+ **qm_params["nn_isolation"],
+ )
+ pc_metrics["nn_isolation"] = isolation
+ pc_metrics["nn_unit_id"] = closest_unit
+
if "silhouette" in metric_names:
silhouette_method = qm_params["silhouette"]["method"]
if "simplified" in silhouette_method:
diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
index 6ddeb02689..403e73f61e 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
@@ -5,6 +5,8 @@
compute_pc_metrics,
)
+from spikeinterface.qualitymetrics.pca_metrics import isolation_score_two_clusters
+
def test_calculate_pc_metrics(small_sorting_analyzer):
import pandas as pd
@@ -22,3 +24,17 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
assert not np.all(np.isnan(res2[metric_name].values))
assert np.array_equal(res1[metric_name].values, res2[metric_name].values)
+
+
+def test_isolation_two_clusters():
+
+ # spikes with pca projections ,
+ pcs = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])
+ labels = np.array([0, 0, 0, 1, 1, 1])
+
+ isolation_score = isolation_score_two_clusters(
+ pcs=pcs, labels=labels, unit_id=0, other_unit_id=1, n_neighbors=2, seed=1205, max_spikes=1000
+ )
+
+ # isolation score is (1 + 1 + 1/2 + 1/2 + 1 + 1)/6 = 5/6
+ assert isolation_score == 5 / 6