Curation Tutorial
After spike sorting and computing quality metrics, you can automatically curate the spike sorting output using the quality metrics.
import spikeinterface as si
import spikeinterface.extractors as se
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import compute_quality_metrics
First, let's download a simulated dataset from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'.
Let's imagine that the ground-truth sorting is in fact the output of a sorter.
local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting = se.read_mearec(local_path)
print(recording)
print(sorting)
MEArecRecordingExtractor: 32 channels - 1 segments - 32.0kHz - 10.000s
file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
file_path: /home/docs/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5
Next, we extract waveforms and compute their principal component (PC) scores:
folder = 'waveforms_mearec'
we = si.extract_waveforms(recording, sorting, folder,
                          load_if_exists=True,
                          ms_before=1, ms_after=2., max_spikes_per_unit=500,
                          n_jobs=1, chunk_size=30000)
print(we)
pc = compute_principal_components(we, load_if_exists=True,
                                  n_components=3, mode='by_channel_local')
print(pc)
/home/docs/checkouts/readthedocs.org/user_builds/spikeinterface-test/checkouts/stable/examples/modules/qualitymetrics/plot_4_curation.py:32: DeprecationWarning: load_if_exists=True/false is deprcated. Use load_waveforms() instead.
we = si.extract_waveforms(recording, sorting, folder,
extract waveforms memmap: 100%|##########| 11/11 [00:00<00:00, 117.79it/s]
WaveformExtractor: 32 channels - 10 units - 1 segments
before:32 after:64 n_per_units:500
Fitting PCA: 100%|##########| 10/10 [00:01<00:00, 8.60it/s]
Projecting waveforms: 100%|##########| 10/10 [00:00<00:00, 94.53it/s]
WaveformPrincipalComponent: 32 channels - 1 segments
mode: by_channel_local n_components: 3
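The numbers in the output above can be checked with simple arithmetic. The sketch below assumes that the millisecond windows are converted to samples as ms * fs / 1000 and that the memmap progress bar counts chunks of `chunk_size` samples; both are assumptions, but they reproduce the printed `before:32 after:64` and the `11/11` chunk count.

```python
import math

# Values taken from the printed recording and the extract_waveforms() call
sampling_frequency = 32000.0  # Hz, "32.0kHz"
duration_s = 10.0             # "10.000s"
ms_before, ms_after = 1.0, 2.0
chunk_size = 30000            # samples per chunk


def ms_to_samples(ms, fs):
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(ms * fs / 1000.0))


nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)

# Total samples per channel and the number of chunks needed to cover them
num_samples = int(duration_s * sampling_frequency)
num_chunks = math.ceil(num_samples / chunk_size)

print(nbefore, nafter)          # 32 64
print(num_samples, num_chunks)  # 320000 11
```

The last chunk is only partially filled (320000 / 30000 is about 10.67), which is why the count is rounded up.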
Then we compute some quality metrics:
metrics = compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'nearest_neighbor'])
print(metrics)
Computing PCA metrics: 100%|##########| 10/10 [00:00<00:00, 13.05it/s]
snr isi_violations_ratio ... nn_hit_rate nn_miss_rate
#0 23.739727 0.0 ... 0.995283 0.003613
#1 25.599155 0.0 ... 0.960000 0.002158
#2 13.819590 0.0 ... 0.924419 0.004274
#3 21.852650 0.0 ... 0.991667 0.000000
#4 7.467602 0.0 ... 0.973958 0.001076
#5 7.465411 0.0 ... 0.972973 0.001412
#6 20.910934 0.0 ... 0.980392 0.000360
#7 7.456506 0.0 ... 0.936364 0.007874
#8 8.052315 0.0 ... 0.969072 0.017241
#9 8.990562 0.0 ... 0.943798 0.009334
[10 rows x 5 columns]
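The `isi_violations_ratio` column is derived from the number of inter-spike intervals (ISIs) shorter than a refractory threshold. spikeinterface normalizes that count with its own formula, which this sketch does not reproduce: it only counts the violating intervals and reports them as a fraction of all intervals. The 1.5 ms threshold and the toy spike train are assumptions for illustration, not values from the MEArec dataset.

```python
def count_isi_violations(spike_times_s, isi_threshold_s=0.0015):
    """Count inter-spike intervals shorter than isi_threshold_s (refractory violations)."""
    times = sorted(spike_times_s)
    isis = [t1 - t0 for t0, t1 in zip(times, times[1:])]
    return sum(1 for isi in isis if isi < isi_threshold_s)


def violation_fraction(spike_times_s, isi_threshold_s=0.0015):
    """Fraction of ISIs that violate the threshold (simplified, not spikeinterface's ratio)."""
    n_intervals = len(spike_times_s) - 1
    if n_intervals < 1:
        return 0.0
    return count_isi_violations(spike_times_s, isi_threshold_s) / n_intervals


# Hypothetical toy spike train in seconds: two ISIs (0.5 ms and 1 ms) are too short
toy_train = [0.010, 0.0105, 0.050, 0.120, 0.1210, 0.300]

print(count_isi_violations(toy_train))  # 2
print(violation_fraction(toy_train))    # 0.4
```

A well-isolated unit such as those in this simulated dataset should have no violations, which is consistent with the 0.0 values in the table.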
We can now threshold each quality metric and select units based on some rules.
The easiest and most intuitive way is to use boolean masking on the metrics dataframe:
# The ISI metric column is named 'isi_violations_ratio' (see the metrics table above)
keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.05) & (metrics['nn_hit_rate'] > 0.90)
print(keep_mask)
# Keep the index labels (unit ids) where the mask is True
keep_unit_ids = keep_mask[keep_mask].index.values
print(keep_unit_ids)
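To see which units the rule keeps, the same thresholds can be applied in plain Python to the values printed in the metrics table. This sketch assumes the column name `isi_violations_ratio` and the thresholds used above; the `passes` helper is hypothetical and only mirrors the pandas mask.

```python
# snr, isi_violations_ratio and nn_hit_rate copied from the printed metrics table
metrics = {
    "#0": {"snr": 23.739727, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.995283},
    "#1": {"snr": 25.599155, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.960000},
    "#2": {"snr": 13.819590, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.924419},
    "#3": {"snr": 21.852650, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.991667},
    "#4": {"snr": 7.467602, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.973958},
    "#5": {"snr": 7.465411, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.972973},
    "#6": {"snr": 20.910934, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.980392},
    "#7": {"snr": 7.456506, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.936364},
    "#8": {"snr": 8.052315, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.969072},
    "#9": {"snr": 8.990562, "isi_violations_ratio": 0.0, "nn_hit_rate": 0.943798},
}


def passes(m, min_snr=7.5, max_isi_ratio=0.05, min_hit_rate=0.90):
    """Hypothetical helper: True if a unit satisfies all three thresholds."""
    return (m["snr"] > min_snr
            and m["isi_violations_ratio"] < max_isi_ratio
            and m["nn_hit_rate"] > min_hit_rate)


keep_unit_ids = [uid for uid, m in metrics.items() if passes(m)]
print(keep_unit_ids)  # ['#0', '#1', '#2', '#3', '#6', '#8', '#9']
```

Here only the SNR threshold removes units: #4, #5 and #7 sit just below 7.5, while every unit has a zero ISI violations ratio and a nearest-neighbor hit rate above 0.90.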
And now let’s create a sorting that contains only curated units and save it, for example to an NPZ file.
curated_sorting = sorting.select_units(keep_unit_ids)
print(curated_sorting)
se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.npz')
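Conceptually, selecting units keeps the spike trains of the listed unit ids and drops the rest. The sketch below models a sorting as a plain dict mapping unit ids to spike sample indices; the data and the `select_units` function here are hypothetical illustrations, and spikeinterface's own `select_units` likely differs internally (it returns a new sorting object rather than a dict).

```python
# Hypothetical toy sorting: unit id -> spike sample indices (not the MEArec data)
spike_trains = {
    "#0": [100, 2500, 7100],
    "#4": [300, 900],
    "#8": [50, 6000, 12000, 18000],
}


def select_units(trains, unit_ids):
    """Return a new dict containing only the requested units, in the requested order."""
    missing = [u for u in unit_ids if u not in trains]
    if missing:
        raise KeyError(f"unknown unit ids: {missing}")
    # Copy each spike train so the original sorting is left untouched
    return {u: list(trains[u]) for u in unit_ids}


curated = select_units(spike_trains, ["#0", "#8"])
print(list(curated))  # ['#0', '#8']
```

Raising on unknown ids is a design choice for the sketch: silently skipping them would hide typos in the curated id list.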