-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_example.py
372 lines (291 loc) · 13.4 KB
/
main_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import probeinterface.plotting as pi_plot
import spikeinterface.full as si
from pathlib import Path
import matplotlib.pyplot as plt
import shutil
# --- Script configuration ---
show_probe = True  # plot the probe geometry after loading the raw recording
show_preprocessing = True  # plot a map of the preprocessed traces
show_waveform = True  # plot the data from one extracted waveform
# Re-use previously saved outputs from disk instead of recomputing (if present)
load_existing_preprocessing = True
load_existing_sorting = True
load_existing_analyzer = False
# NOTE(review): machine-specific absolute path — edit for your own checkout.
base_path = Path(r"C:\Users\Joe\PycharmProjects\course2024\course-extracellular-ephys-analysis\example_data")
# Raw data and derived outputs follow the same sub-XXX/ses-XXX/ephys layout.
data_path = base_path / "rawdata" / "sub-001" / "ses-001" / "ephys"
output_path = base_path / "derivatives" / "sub-001" / "ses-001" / "ephys"
# -------------------------------------------------------------------------------------
# Loading Raw Data
# -------------------------------------------------------------------------------------
raw_recording = si.read_spikeglx(data_path)
if show_probe:
    # Visualise the probe geometry attached to the recording.
    probe = raw_recording.get_probe()
    pi_plot.plot_probe(probe, with_contact_id=True)
    plt.show()
# Extra things to try
print(raw_recording)
# SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 90,000 samples - 3.00s
# int16 dtype - 65.92 MiB
# It is a SpikeGLXRecordingExtractor class
print(dir(raw_recording))
# dir() shows all class attributes and methods for the class
print(raw_recording.get_sampling_frequency())
# 30 kHz. This class method can be seen in the results from dir()
example_data = raw_recording.get_traces(start_frame=0, end_frame=1000, return_scaled=True)
# Returns the data as a num_samples x num_channels array. We index from the first
# (index 0) sample to the 1000th (index 999, as end_frame is upper-bound exclusive).
# The `raw_recording` is a lazy object that only loads data into memory when requested.
# Compute and plot the Fourier transform of a single channel's trace.
import numpy as np

sampling_frequency = 30000
channel_trace = example_data[:, 0]
n_samples = channel_trace.size
# Frequency bin centres for an n_samples-point FFT at this sampling rate.
freqs = np.fft.fftfreq(n_samples, d=1 / sampling_frequency)
# Scale the magnitude spectrum by 2/N so peak heights approximate signal amplitudes.
scaled_magnitude = np.abs(np.fft.fft(channel_trace)) * 2 / n_samples
plt.plot(freqs, scaled_magnitude)
plt.ylabel("Scaled Magnitude")
plt.xlabel("Frequency (Hz)")
plt.title("Demeaned Signal Frequency Spectrum")
plt.show()
# -------------------------------------------------------------------------------------
# Preprocessing
# -------------------------------------------------------------------------------------
preprocessed_output_path = output_path / "preprocessed"
if preprocessed_output_path.is_dir() and load_existing_preprocessing:
    # Reload preprocessing that was previously saved to disk.
    preprocessed_recording = si.load_extractor(preprocessed_output_path)
else:
    # Lazy preprocessing chain: each step wraps the previous recording and is
    # only evaluated when traces are requested (or when .save() runs the chain).
    shifted_recording = si.phase_shift(raw_recording)
    filtered_recording = si.bandpass_filter(
        shifted_recording, freq_min=300, freq_max=6000
    )
    common_referenced_recording = si.common_reference(
        filtered_recording, reference="global", operator="median"
    )
    whitened_recording = si.whiten(
        common_referenced_recording, dtype="float32",
    )
    preprocessed_recording = si.correct_motion(
        whitened_recording, preset="kilosort_like"
    )  # see also 'nonrigid_accurate'
    # Run the chain and persist the result so later runs can reload it.
    preprocessed_recording.save(folder=preprocessed_output_path, overwrite=True)
if show_preprocessing:
    si.plot_traces(
        preprocessed_recording,
        order_channel_by_depth=True,
        time_range=(2, 3),
        return_scaled=True,
        show_channel_ids=True,
        mode="map",  # "map", "line"
        clim=(-10, 10),  # after whitening, use (-10, 10) otherwise use (-200, 200)
    )
    plt.show()
# -------------------------------------------------------------------------------------
# Extra things to try - preprocessing
# -------------------------------------------------------------------------------------
# Whitening completely changes the scaling of the data.
# The data looks very similar when unscaled, as int16. This is because scaling
# changes neither the range nor the precision of the values; it only places them
# in more interpretable units. Even though the scaled data is float rather than
# int16, the data was acquired as int16, so scaling alone does not increase the
# resolution of the recording.
# We can save the data with
# preprocessed_data_path = output_path / "preprocessed_data"
# preprocessed_recording.save(folder=preprocessed_data_path)
# Setting the bandpass filter minimum cutoff to zero would include frequencies
# in the range 0 - 6000 Hz.
# Setting the bandpass filter maximum (having returned the minimum to 300) would include
# 300 - 15000 Hz. 15000 Hz is chosen as it is the Nyquist frequency (half the sampling
# rate of 30 kHz) that represents the largest detectable frequency in the recorded signal.
# using preprocessed_recording.get_traces(start_frame=0, end_frame=1000, return_scaled=True)
# (or False; 1000 samples are taken arbitrarily) shows the int16 data as acquired (when False)
# or the same data scaled to microvolts. You will see that if two datapoints are the same
# value when int16, they are the same value as microvolts.
# Demonstration: crude threshold-based spike detection on one preprocessed channel.
example_prepro_data = preprocessed_recording.get_traces(0, 60000)
channel_trace = example_prepro_data[:, 300]
# Ideally we would measure the noise level with spikes removed first.
standard_dev = np.std(channel_trace)
mean_ = np.mean(channel_trace)
# Detection threshold: 3 standard deviations below the mean (spikes are negative-going).
standard_dev_cutoff = mean_ - 3 * standard_dev
spike_indices = np.where(channel_trace < standard_dev_cutoff)
# Plain thresholding with np.where is not good, as it takes multiple points on a
# single action potential. Instead, use scipy's peak finder with a minimum distance.
# BUGFIX: was `import scipy` — that alone does not guarantee the `scipy.signal`
# subpackage is imported; import it explicitly.
import scipy.signal

distance_between_peaks_in_samples = int(0.001 * sampling_frequency)  # spikes at least 1ms apart
# scipy's find_peaks is annoying here: it has no simple threshold cutoff, which we
# want for this example. Visualising shows the threshold under-estimates the peaks
# because of the way prominence works, so we emulate a cutoff by re-thresholding
# at 3x the std afterwards.
# Note: there are peak finding algorithms in other packages you can use for threshold
# cutoffs, see https://stackoverflow.com/questions/1713335/peak-finding-algorithm-for-python-scipy
# Also, in general there are better ways to find peaks, i.e. template matching.
spike_indices = scipy.signal.find_peaks(
    channel_trace * -1,  # invert so the negative spikes become positive peaks
    distance=distance_between_peaks_in_samples,
    prominence=standard_dev_cutoff * -1,  # cutoff is negative, so this bound is positive
)[0]
# Keep only peaks that actually cross the threshold on the original (negative) trace.
spike_indices = spike_indices[channel_trace[spike_indices] < standard_dev_cutoff]
plt.plot(channel_trace)
plt.scatter(spike_indices, channel_trace[spike_indices])
plt.show()
# -------------------------------------------------------------------------------------
# Sorting
# -------------------------------------------------------------------------------------
sorting_path = output_path / "sorting"
# Reload an existing sort if the sorter's output file is already on disk.
if (expected_filepath := sorting_path / "sorter_output" / "firings.npz").is_file() and load_existing_sorting:
    sorting = si.NpzSortingExtractor(
        expected_filepath
    )
else:
    # Filtering and whitening were already applied in preprocessing,
    # so they are disabled in the sorter itself.
    sorting = si.run_sorter(
        "mountainsort5",
        preprocessed_recording,
        folder=sorting_path,
        remove_existing_folder=True,
        filter=False,
        whiten=False,
    )
sorting = sorting.remove_empty_units()
# NOTE(review): presumably trims spikes falling outside the recording's duration,
# which can break downstream steps — confirm against SpikeInterface docs.
sorting = si.remove_excess_spikes(
    sorting, preprocessed_recording
)
# NOTE(review): return value is discarded — wrap in print() to inspect the defaults.
si.get_default_sorter_params(sorter_name_or_class="mountainsort5")
# Sorting - Extra things to try
# -----------------------------
# Use this function to get the times of all APs for a unit.
spike_times = sorting.get_unit_spike_train(unit_id=2, return_times=True)
# Use the si.plot_rasters function as below to view the unit spikes
# as a raster plot.
si.plot_rasters(sorting, unit_ids=[2])
plt.show()
# The conditional statement is as above.
# To sort with kilosort4, change "mountainsort5" to "kilosort4" above
# (and ensure you have kilosort installed).
# -------------------------------------------------------------------------------------
# Postprocessing - Sorting Analyzer
# -------------------------------------------------------------------------------------
sorting_analyzer_path = output_path / "analyzer"
quality_metrics_path = output_path / "quality_metrics.csv"
if sorting_analyzer_path.is_dir() and load_existing_analyzer:
    # Reload a previously computed analyzer from disk.
    analyzer = si.load_sorting_analyzer(sorting_analyzer_path)
else:
    # Sparse analyzer: each unit keeps only channels within 75 um (radius method).
    analyzer = si.create_sorting_analyzer(
        sorting=sorting,
        recording=preprocessed_recording,
        radius_um=75,
        method="radius",
        sparse=True,
    )
    # Extensions build on one another (random_spikes -> waveforms -> templates, ...),
    # so the compute order below matters.
    analyzer.compute(
        "random_spikes",
        method='uniform',
        max_spikes_per_unit=500,
        seed=None
    )
    analyzer.compute(
        "waveforms",
        ms_before=2,
        ms_after=2,
    )
    analyzer.compute(
        "templates",
        ms_before=2,
        ms_after=2,
        operators=["average"]
    )
    analyzer.compute(
        "spike_amplitudes",
        peak_sign="neg",
    )
    analyzer.compute(
        "noise_levels",
        num_chunks_per_segment=20,
        chunk_size=10000,
        seed=None
    )
    analyzer.compute(
        "principal_components",
        n_components=5,
    )
    analyzer.compute(
        "quality_metrics",
        peak_sign="neg",
        seed=None,
        skip_pc_metrics=False,
        delete_existing_metrics=False,
    )
    # Clear any stale analyzer folder before saving — presumably save_as() does
    # not overwrite an existing folder, hence the explicit rmtree.
    if sorting_analyzer_path.is_dir():
        shutil.rmtree(sorting_analyzer_path)
    analyzer.save_as(format="binary_folder", folder=sorting_analyzer_path)
# Export the per-unit quality metrics table to CSV.
quality_metrics_table = analyzer.get_extension("quality_metrics").get_data()
quality_metrics_table.to_csv(quality_metrics_path)
# -------------------------------------------------------------------------------------
# Plot Waveforms
# -------------------------------------------------------------------------------------
valid_unit_ids = analyzer.unit_ids
unit_to_show = valid_unit_ids[0]  # show the first unit that survived curation
waveforms = analyzer.get_extension("waveforms")
unit_waveform_data = waveforms.get_waveforms_one_unit(
    unit_id=unit_to_show
)
print(f"The shape of the waveform data is "
      f"num_waveforms x num_samples x num_channels: {unit_waveform_data.shape}"
)
if show_waveform:
    # BUGFIX: this assignment was previously duplicated outside the `if` as well;
    # the extra copy was dead code and has been removed.
    single_waveform_data = unit_waveform_data[0, :, :]
    plt.plot(single_waveform_data)
    plt.title("Data from a single waveform")
    plt.show()
templates = analyzer.get_extension("templates")
unit_template_data = templates.get_unit_template(unit_id=unit_to_show)
# BUGFIX: the original f-string was missing a space, printing "shapeof".
print(f"The template is averaged over all waveforms. The shape "
      f"of the template data is num_samples x num_channels: {unit_template_data.shape}")
si.plot_unit_waveforms_density_map(analyzer, unit_ids=[unit_to_show])
plt.show()
plt.plot(unit_template_data)
plt.title(f"Template for unit: {unit_to_show}")
plt.show()
si.plot_unit_templates(analyzer, unit_ids=[unit_to_show])
plt.show()
# -------------------------------------------------------------------------------------
# Extra things to try
# -------------------------------------------------------------------------------------
# Index out the most-negative channel and plot it.
unit_waveform_data = waveforms.get_waveforms_one_unit(unit_id=2)  # analyzer.get_waveforms(unit_id=2)
import numpy as np
# Let's index out the data from a single action potential.
# `first_ap_data` is a num_samples x num_channels array
first_ap_data = unit_waveform_data[0, :, :]
# Let's just find the most-negative (i.e. the largest negative peak) value
# across all timepoints, all channels. This is over a 2D array, but np.argmin
# flattens the array, and so the index of the value is into this flattened array (it is
# something like 700). We need to convert it back to a 2D index (index of the
# timepoint, index of the channel). To do this, we use the `unravel_index` function.
largest_neg_idx = np.argmin(first_ap_data)
largest_neg_idx = np.unravel_index(largest_neg_idx, first_ap_data.shape)
# Now, we index out the num_samples x num_channel array at the channel in which
# the most negative value was found. `largest_neg_idx` has two values:
# the first is the timepoint where the minimum value was found, the second is the
# channel where the minimum value was found.
largest_neg_channel = first_ap_data[:, largest_neg_idx[1]]
plt.plot(largest_neg_channel)
plt.show()
# Now, we want to plot a dot over the plot of the AP. The axes of the AP
# plot are the voltage on the y-axis, and the sample index on the x-axis.
# The sample index is the timepoint of where the negative value was found,
# as above. We can get the voltage by finding the minimum value of the AP.
# Then, we can plot as plt.scatter(<index of min value>, <min value>).
peak_value = np.min(largest_neg_channel)
plt.plot(largest_neg_channel)
plt.scatter(largest_neg_idx[0], peak_value)
plt.show()
# Finally, we can add offsets to the waveforms for plotting purposes, to see
# how the signal loads across channels. In practice, we would use SpikeInterface's
# own `plot_traces()` for this.
num_channels = unit_template_data.shape[1]
offset = 0.2
# One offset per channel, shaped (1, num_channels) so it broadcasts over samples.
channel_offsets = np.linspace(0, offset * num_channels, num_channels)[np.newaxis, :]
offset_template_data = unit_template_data + channel_offsets
plt.plot(offset_template_data)
plt.show()
# Templates extra things to try: make a custom template by averaging this unit's
# waveforms ourselves; the result is num_samples x num_channels.
template_waveform = np.mean(unit_waveform_data, axis=0)
# BUGFIX: the original indexed `channel_offsets[:, template_waveform.shape[1]]`,
# which reads one column past the end of the (1, num_channels) array and raises
# IndexError. Slice to the template's channel count instead so the offsets
# broadcast across samples (the two channel counts may differ if the units'
# sparse channel sets differ — hence the slice rather than a bare add).
plt.plot(template_waveform + channel_offsets[:, :template_waveform.shape[1]])
plt.show()