diff --git a/doc/modules/qualitymetrics/hit_miss_rate.svg b/doc/modules/qualitymetrics/hit_miss_rate.svg new file mode 100644 index 0000000000..8526486dbc --- /dev/null +++ b/doc/modules/qualitymetrics/hit_miss_rate.svg @@ -0,0 +1,380 @@ + + + +a1b1Cluster ACluster BCluster Cc1 diff --git a/doc/modules/qualitymetrics/nearest_neighbor.rst b/doc/modules/qualitymetrics/nearest_neighbor.rst index bbd8f6628a..8e9470a9ca 100644 --- a/doc/modules/qualitymetrics/nearest_neighbor.rst +++ b/doc/modules/qualitymetrics/nearest_neighbor.rst @@ -17,52 +17,66 @@ All options involve non-parametric calculations in PCA space. :code:`nearest_neighbor` ------------------------ -The membership function, :math:`\rho` is defined such that for any spike :math:`g_i`` in some cluster :math:`G`, :math:`\rho(g_i) = G`. -Additionally, the nearest neighbor function :math:`n_k(g_i)` is defined such that the output of the function is the set of :math:`k` spikes which are closest to :math:`g_i`. +Consider the set of all spikes :math:`S` and a unit :math:`A` (where :math:`A` is a subset of :math:`S`) with +spikes :math:`a_i \in A`. Each spike has a principal component projection, which is a vector. +We can use these vectors to compute the Euclidean distance on :math:`S`. The :math:`k` nearest +neighbors to a spike :math:`a_i` are the :math:`k` spikes closest to it with respect to this distance. -For a unit associated with cluster :math:`C`, a subset of spikes are randomly drawn to form the cluster :math:`A`. -A subset of spikes which are not in :math:`C` are drawn to form the cluster :math:`B`. -Note that :math:`|A| = |B|`. -The NN-hit rate for :math:`C` is then: +We can define a hit rate and miss rate of a spike based on the notion of nearest neighbors. The hit rate tells us +what fraction of the :math:`k` nearest neighbors of :math:`a_i` are also in :math:`A`, rather than :math:`\neg A`. It is defined as ..
math:: - NN_{\textrm{hit}}(C) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in A : \rho(n_i(x)) = A \} |}{ | A | } + NN^{\textrm{hit}}_k(a_i) = \frac{1}{k}(\text{number of the } k \text{ nearest neighbors of } a_i \text{ in } A ) +The hit rate of a unit is then -Similarly, the NN-miss rate for :math:`C` is: +.. math:: + NN^{\textrm{hit}}_k(A) = \sum_{a_i \in A} NN^{\textrm{hit}}_k(a_i) + +To compute the miss rate of :math:`A`, consider the spikes :math:`b_j` which are *not* in :math:`A`. The miss rate +measures the number of these whose nearby neighbors *are* in :math:`A`. It is defined as .. math:: - NN_{\textrm{miss}}(C) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in B : \rho(n_i(x)) = A \} |}{ | B | } + NN^{\textrm{miss}}_k(b_j,A) = \frac{1}{k}(\text{number of the } k \text{ nearest neighbors of } b_j \text{ in } A ) -NN-hit rate gives an estimate of contamination (an uncontaminated unit should have a high NN-hit rate). -NN-miss rate gives an estimate of completeness. -A more complete unit should have a low NN-miss rate. +The miss rate of the unit :math:`A` can be written in terms of all the other units :math:`B, C, ...` as -:code:`nn_isolation` --------------------- +.. math:: + &NN^{\textrm{miss}}_k(A) = \sum_{b \in \neg A} NN^{\textrm{miss}}_k(b,A) \\ + &= \sum_{b_j \in B} NN^{\textrm{miss}}_k(b_j, A) + \sum_{c_l \in C} NN^{\textrm{miss}}_k(c_l, A) + \ldots -The overall logic of this approach is to choose a cluster for which the isolation is to be computed, and compute the pairwise isolation score between the chosen cluster and every other cluster. -The isolation score is then the minimum of the pairwise scores (the worst case). +A visualisation of the hit and miss rates for individual spikes is shown below: -Let A and B be two clusters from sorting. -We set :math:`|A| = |B|` by subsampling as appropriate to match the size of the smaller cluster (or the :code:`max_spikes_for_nn` parameter value, if using).
-We also restrict the waveforms to channels with significant signal. +.. image:: hit_miss_rate.svg + :width: 600 + :align: center -The pairwise isolation between clusters A and B is then: + +:code:`nn_isolation` +-------------------- + +The pairwise isolation score between two units :math:`A` and :math:`B` is a measurement of how well +separated they are. It can be written in terms of the hit rates, although here we only consider +:math:`A\cup B` rather than the space of all spikes. It is then defined as .. math:: + NN^\textrm{isolation}_k(A,B) = \frac{1}{|A\cup B|} \sum_{x \in A\cup B} NN^{\textrm{hit}}_k(x) - NN_{\textrm{isolation}}(A, B) = \frac{1}{k} \sum_{i=1}^{k} \frac{ | \{x \in A \cup B : \rho(n_i(x)) = \rho(x) \} |}{ | A \cup B | } +The isolation of a unit :math:`A` is then given by the worst (minimum) pairwise isolation score: +.. math:: + NN^\textrm{isolation}_k(A) = \min_{B \in \text{all other units}} NN^\textrm{isolation}_k(A,B) -Note that nn_isolation is affected by the size of the clusters, so setting the :code:`max_spikes_for_nn` may aid downstream comparison of scores. +If the two units have different numbers of spikes, the metric can give misleading results. Hence we randomly +sample the spikes in each unit so that :math:`|A| = |B| = \text{min}(|A|,|B|, \text{max_spikes})`. :code:`nn_noise_overlap` ------------------------ -A noise cluster is generated by randomly sampling voltage snippets from the recording. -Following a similar procedure to that of the nn_isolation method, compute isolation between the cluster of interest and the generated noise cluster. -noise overlap is then :math:`1 - NN_{\textrm{isolation}}`. +A noise cluster :math:`C` is generated by randomly sampling voltage snippets from the recording. +Following a similar procedure to that of the nn_isolation method, compute isolation between the +cluster of interest :math:`A` and the generated noise cluster. +The noise overlap is then :math:`1 - NN^{\textrm{isolation}}_k(A,C)`.
This metric gives an indication of the contamination present in the unit cluster. diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 1c5a491bf8..8ff5d7b113 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -33,9 +33,7 @@ max_spikes=10000, n_neighbors=5, ), - nn_isolation=dict( - max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" - ), + nn_isolation=dict(max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=5), nn_noise_overlap=dict( max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" ), @@ -107,15 +105,16 @@ def compute_pc_metrics( pc_metrics["nn_miss_rate"] = {} if "nn_isolation" in metric_names: + pc_metrics["nn_isolation"] = {} pc_metrics["nn_unit_id"] = {} - possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"] + possible_nn_metrics = ["nn_noise_overlap"] nn_metrics = list(set(metric_names).intersection(possible_nn_metrics)) non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics)) # Compute nspikes and firing rate outside of main loop for speed - if nn_metrics: + if nn_metrics or "nn_isolation" in metric_names: n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) else: @@ -196,16 +195,10 @@ def compute_pc_metrics( **metric_params, ) except: - if metric_name == "nn_isolation": - res = (np.nan, np.nan) - elif metric_name == "nn_noise_overlap": + if metric_name == "nn_noise_overlap": res = np.nan - if metric_name == "nn_isolation": - nn_isolation, nn_unit_id = res - pc_metrics["nn_isolation"][unit_id] = nn_isolation - pc_metrics["nn_unit_id"][unit_id] = nn_unit_id - elif metric_name == "nn_noise_overlap": + if metric_name == "nn_noise_overlap": pc_metrics["nn_noise_overlap"][unit_id] = res return 
pc_metrics @@ -410,18 +403,16 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n def nearest_neighbors_isolation( - sorting_analyzer, + all_units_ids, + pcs_flat, + labels, this_unit_id: int | str, - n_spikes_all_units: dict = None, - fr_all_units: dict = None, + n_spikes_all_units: dict, + fr_all_units: dict, max_spikes: int = 1000, min_spikes: int = 10, min_fr: float = 0.0, n_neighbors: int = 5, - n_components: int = 10, - radius_um: float = 100, - peak_sign: str = "neg", - min_spatial_overlap: float = 0.5, seed=None, ): """ @@ -429,16 +420,18 @@ def nearest_neighbors_isolation( Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. + all_units_ids : np.array + List of unit_ids for the sorting + pcs_flat : 2d array + The PCs for all spikes, organized as [num_spikes, PCs]. + labels : 1d array + The unit labels for all spikes. Must have length of number of spikes. this_unit_id : int | str The ID for the unit to calculate these metrics for. - n_spikes_all_units : dict, default: None - Dictionary of the form ``{: }`` for the waveform extractor. - Recomputed if None. - fr_all_units : dict, default: None - Dictionary of the form ``{: }`` for the waveform extractor. - Recomputed if None. + n_spikes_all_units : dict + Dictionary of the form ``{: }`` for the sorting analyzer. + fr_all_units : dict + Dictionary of the form ``{: }`` for the sorting analyzer. max_spikes : int, default: 1000 Max number of spikes to use per unit. min_spikes : int, default: 10 @@ -451,16 +444,6 @@ def nearest_neighbors_isolation( and are ignored when selecting other units' neighbors. n_neighbors : int, default: 5 Number of neighbors to check membership of. - n_components : int, default: 10 - The number of PC components to use to project the snippets to. - radius_um : float, default: 100 - The radius, in um, that channels need to be within the peak channel to be included. 
- peak_sign : "neg" | "pos" | "both", default: "neg" - The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer - is not sparse already. - min_spatial_overlap : float, default: 100 - In case sorting_analyzer is sparse, other units are selected if they share at least - `min_spatial_overlap` times `n_target_unit_channels` with the target unit. seed : int, default: None Seed for random subsampling of spikes. @@ -469,7 +452,7 @@ def nearest_neighbors_isolation( nn_isolation : float The calculation nearest neighbor isolation metric for `this_unit_id`. If the unit has fewer than `min_spikes`, returns numpy.NaN instead. - nn_unit_id : np.int16 + nn_unit_id : float Id of the "nearest neighbor" unit (unit with lowest isolation score from `this_unit_id`). Notes @@ -480,41 +463,15 @@ def nearest_neighbors_isolation( #. Compute the isolation score with every other cluster #. Isolation score is defined as the min of 2. (i.e. 'worst-case measure') - The implementation of this approach is: - - Let A and B be two clusters from sorting. - - We set \\|A\\| = \\|B\\|: - - * | If max_spikes < \\|A\\| and max_spikes < \\|B\\|: - | Then randomly subsample max_spikes samples from A and B. - * | If max_spikes > min(\\|A\\|, \\|B\\|) (e.g. \\|A\\| > max_spikes > \\|B\\|): - | Then randomly subsample min(\\|A\\|, \\|B\\|) samples from A and B. + We set \\|A\\| = \\|B\\| = min(\\|A\\|, \\|B\\|, max_spikes). This is because the metric is affected by the size of the clusters being compared independently of how well-isolated they are. - We also restrict the waveforms to channels with significant signal. - - See docstring for `_compute_isolation` for the definition of isolation score. 
- References ---------- Based on isolation metric described in [Chung]_ """ - from sklearn.decomposition import IncrementalPCA - - rng = np.random.default_rng(seed=seed) - - waveforms_ext = sorting_analyzer.get_extension("waveforms") - assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'" - - sorting = sorting_analyzer.sorting - all_units_ids = sorting.get_unit_ids() - if n_spikes_all_units is None: - n_spikes_all_units = compute_num_spikes(sorting_analyzer) - if fr_all_units is None: - fr_all_units = compute_firing_rates(sorting_analyzer) # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: @@ -530,100 +487,66 @@ def nearest_neighbors_isolation( ) return np.nan, np.nan else: - # first remove the units with too few spikes - unit_ids_to_keep = np.array( - [ - unit - for unit in all_units_ids - if (n_spikes_all_units[unit] >= min_spikes and fr_all_units[unit] >= min_fr) - ] - ) - sorting = sorting.select_units(unit_ids=unit_ids_to_keep) + other_unit_ids = np.delete(all_units_ids, np.argwhere(all_units_ids == this_unit_id)[0][0]) + all_scores = np.zeros(len(other_unit_ids)) + for other_unit_ind, other_unit_id in enumerate(other_unit_ids): + all_scores[other_unit_ind] = isolation_score_two_clusters( + pcs_flat, labels, this_unit_id, other_unit_id, n_neighbors, seed, max_spikes + ) + return np.min(all_scores), other_unit_ids[np.argmin(all_scores)] - all_units_ids = sorting.get_unit_ids() - other_units_ids = np.setdiff1d(all_units_ids, this_unit_id) - # get waveforms of target unit - # waveforms_target_unit = sorting_analyzer.get_waveforms(unit_id=this_unit_id) - waveforms_target_unit = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False) +def isolation_score_two_clusters(pcs, labels, unit_id, other_unit_id, n_neighbors, seed, max_spikes): + """ + Calculate pairwise isolation score of two units. 
- n_spikes_target_unit = waveforms_target_unit.shape[0] + Parameters + ---------- + pcs : 2d array + The PCs for all spikes, organized as [num_spikes, PCs]. + labels : 1d array + The cluster labels for all spikes. Must have length of number of spikes. + unit_id : int | str + The ID for the unit being compared. + other_unit_id : int | str + The ID for the other unit which is being compared against. + max_spikes : int, default: 1000 + Max number of spikes to use per unit. + n_neighbors : int, default: 5 + Number of neighbors to check membership of. + seed : int, default: None + Seed for random subsampling of spikes. - # find units whose signal channels (i.e. channels inside some radius around - # the channel with largest amplitude) overlap with signal channels of the target unit - if sorting_analyzer.is_sparse(): - sparsity = sorting_analyzer.sparsity - else: - sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) - closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id] - n_channels_target_unit = len(closest_chans_target_unit) - # select other units that have a minimum spatial overlap with target unit - other_units_ids = [ - unit_id - for unit_id in other_units_ids - if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) - >= (n_channels_target_unit * min_spatial_overlap) - ] - - # if no unit is within neighborhood of target unit, then just say isolation is 1 (best possible) - if not other_units_ids: - nn_isolation = 1 - nn_unit_id = np.nan - # if there are units to compare, then compute isolation with each - else: - isolation = np.zeros( - len(other_units_ids), - ) - for other_unit_id in other_units_ids: - # waveforms_other_unit = sorting_analyzer.get_waveforms(unit_id=other_unit_id) - waveforms_other_unit = waveforms_ext.get_waveforms_one_unit(unit_id=other_unit_id, force_dense=False) - - n_spikes_other_unit = waveforms_other_unit.shape[0] - 
closest_chans_other_unit = sparsity.unit_id_to_channel_indices[other_unit_id] - n_snippets = np.min([n_spikes_target_unit, n_spikes_other_unit, max_spikes]) - - # make the two clusters equal in terms of: number of spikes & channels with signal - waveforms_target_unit_idx = rng.choice(n_spikes_target_unit, size=n_snippets, replace=False) - waveforms_target_unit_sampled = waveforms_target_unit[waveforms_target_unit_idx] - waveforms_other_unit_idx = rng.choice(n_spikes_other_unit, size=n_snippets, replace=False) - waveforms_other_unit_sampled = waveforms_other_unit[waveforms_other_unit_idx] - - # project this unit and other unit waveforms on common subspace - common_channel_idxs = np.intersect1d(closest_chans_target_unit, closest_chans_other_unit) - if sorting_analyzer.is_sparse(): - # in this case, waveforms are sparse so we need to do some smart indexing - waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.isin(closest_chans_target_unit, common_channel_idxs) - ] - waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.isin(closest_chans_other_unit, common_channel_idxs) - ] - else: - waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] - waveforms_other_unit_sampled = waveforms_other_unit_sampled[:, :, common_channel_idxs] - - # compute principal components after concatenation - all_snippets = np.concatenate( - [ - waveforms_target_unit_sampled.reshape((n_snippets, -1)), - waveforms_other_unit_sampled.reshape((n_snippets, -1)), - ], - axis=0, - ) - pca = IncrementalPCA(n_components=n_components) - pca.partial_fit(all_snippets) - projected_snippets = pca.transform(all_snippets) + Returns + ------- + isolation_score : float + Pairwise isolation score between the two units. 
+ """ - # compute isolation - isolation[other_unit_id == other_units_ids] = _compute_isolation( - projected_snippets[:n_snippets, :], projected_snippets[n_snippets:, :], n_neighbors - ) - # isolation metric is the minimum of the pairwise isolations - # nn_unit_id is the unit with lowest isolation score - nn_isolation = np.min(isolation) - nn_unit_id = other_units_ids[np.argmin(isolation)] + rng = np.random.default_rng(seed=seed) + from sklearn.neighbors import NearestNeighbors - return nn_isolation, nn_unit_id + pcs_A = pcs[labels == unit_id] + pcs_B = pcs[labels == other_unit_id] + n_A = len(pcs_A) + n_B = len(pcs_B) + + # ensure sample from clusters is the same size + if n_A < n_B or n_A > max_spikes: + pcs_B = pcs_B[rng.choice(n_B, size=min(n_A, max_spikes), replace=False)] + else: + pcs_A = pcs_A[rng.choice(n_A, size=min(n_B, max_spikes), replace=False)] + + X = np.concatenate((pcs_A, pcs_B), 0) + + nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm="ball_tree").fit(X) + _, indices = nbrs.kneighbors(X) + + A_isolated = np.sum(indices[:n_A, 1:] < n_A) + B_isolated = np.sum(indices[n_B:, 1:] >= n_B) + + isolation_score = (A_isolated + B_isolated) / ((n_neighbors) * (n_A + n_B)) + return isolation_score def nearest_neighbors_noise_overlap( @@ -1029,6 +952,20 @@ def pca_metrics_one_unit(args): pc_metrics["nn_hit_rate"] = nn_hit_rate pc_metrics["nn_miss_rate"] = nn_miss_rate + if "nn_isolation" in metric_names: + isolation, closest_unit = nearest_neighbors_isolation( + all_units_ids=unit_ids, + pcs_flat=pcs_flat, + labels=labels, + this_unit_id=unit_id, + fr_all_units=fr_all_units, + n_spikes_all_units=n_spikes_all_units, + seed=seed, + **qm_params["nn_isolation"], + ) + pc_metrics["nn_isolation"] = isolation + pc_metrics["nn_unit_id"] = closest_unit + if "silhouette" in metric_names: silhouette_method = qm_params["silhouette"]["method"] if "simplified" in silhouette_method: diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py 
b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..403e73f61e 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -5,6 +5,8 @@ compute_pc_metrics, ) +from spikeinterface.qualitymetrics.pca_metrics import isolation_score_two_clusters + def test_calculate_pc_metrics(small_sorting_analyzer): import pandas as pd @@ -22,3 +24,17 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + + +def test_isolation_two_clusters(): + + # spikes with pca projections , + pcs = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]) + labels = np.array([0, 0, 0, 1, 1, 1]) + + isolation_score = isolation_score_two_clusters( + pcs=pcs, labels=labels, unit_id=0, other_unit_id=1, n_neighbors=2, seed=1205, max_spikes=1000 + ) + + # isolation score is (1 + 1 + 1/2 + 1/2 + 1 + 1)/6 = 5/6 + assert isolation_score == 5 / 6
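
Note on the subsampling in `isolation_score_two_clusters`: as written in the hunk above, only one of the two clusters is subsampled per branch, and `n_A` / `n_B` are not recomputed afterwards. The slices `indices[:n_A]` and `indices[n_B:]` and the denominator therefore appear to match the data only when both clusters already have the same size and neither exceeds `max_spikes`, which is the only case the new test exercises. The sketch below shows one possible way to compute the same score with both clusters subsampled to `min(|A|, |B|, max_spikes)`. It is an illustration, not the library implementation: the name `pairwise_isolation_sketch` is hypothetical, it uses a brute-force neighbor search from the standard library in place of `sklearn.neighbors.NearestNeighbors`, and it assumes both clusters are non-empty and that `n_neighbors` is smaller than the pooled sample size.

```python
import math
import random


def pairwise_isolation_sketch(points, labels, unit_id, other_unit_id, n_neighbors, max_spikes, seed=None):
    """Brute-force pairwise isolation score (hypothetical helper, for illustration only)."""
    rng = random.Random(seed)
    pts_a = [p for p, lab in zip(points, labels) if lab == unit_id]
    pts_b = [p for p, lab in zip(points, labels) if lab == other_unit_id]

    # Subsample BOTH clusters to the same size n = min(|A|, |B|, max_spikes).
    n = min(len(pts_a), len(pts_b), max_spikes)
    pts_a = rng.sample(pts_a, n)
    pts_b = rng.sample(pts_b, n)

    pooled = pts_a + pts_b
    membership = [0] * n + [1] * n  # 0 -> cluster A, 1 -> cluster B

    same_label_neighbors = 0
    for i, p in enumerate(pooled):
        # All other pooled points, ordered by Euclidean distance to p.
        others = sorted(
            (j for j in range(len(pooled)) if j != i),
            key=lambda j: math.dist(p, pooled[j]),
        )
        # Count how many of the k nearest neighbors share p's cluster.
        same_label_neighbors += sum(membership[j] == membership[i] for j in others[:n_neighbors])

    # Mean per-spike hit rate over the pooled sample of 2 * n spikes.
    return same_label_neighbors / (n_neighbors * 2 * n)


# Same toy data as test_isolation_two_clusters in the diff above.
pcs = [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]
labels = [0, 0, 0, 1, 1, 1]
score = pairwise_isolation_sketch(pcs, labels, 0, 1, n_neighbors=2, max_spikes=1000, seed=1205)
```

On the toy data from the test, this reproduces the hand-computed value (1 + 1 + 1/2 + 1/2 + 1 + 1)/6 = 5/6. With unequal cluster sizes or a small `max_spikes`, both clusters are reduced to the same size before the neighbor search, so the denominator always equals the number of pooled spikes times `n_neighbors`.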