Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
33c6769
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 16, 2026
2c94bac
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 20, 2026
a40d073
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Feb 24, 2026
ef40b73
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 17, 2026
11c5812
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 24, 2026
ada53f8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 24, 2026
22ff8fd
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 25, 2026
cbc36de
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 31, 2026
6b3e373
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Apr 9, 2026
359b68b
Implement get_unit_spike_trains function
alejoe91 Apr 9, 2026
85220e5
oups
alejoe91 Apr 9, 2026
0efad83
Fix tests
alejoe91 Apr 10, 2026
b1911bf
add tests and fixes
alejoe91 Apr 10, 2026
0744705
Fix bugs in get_unit_spike_trains_in_seconds and segment keying
grahamfindlay Apr 10, 2026
c71550b
Fix lexsort avoidance check in UnitsSelectionSorting (USS)
grahamfindlay Apr 10, 2026
6a82577
Override _compute_and_cache_spike_vector in Phy/Kilosort extractors
grahamfindlay Apr 10, 2026
1d4a3ce
Optimize get_unit_spike_trains on PhySortingSegment
grahamfindlay Apr 10, 2026
9a139b5
Add tests for UnitSelectionSorting & Phy spike vector and train optim…
grahamfindlay Apr 13, 2026
832f44f
Move is_spike_vector_sorted to sorting_tools
alejoe91 Apr 14, 2026
fe15764
Merge pull request #28 from grahamfindlay/pr4502-graham
alejoe91 Apr 14, 2026
329d220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2026
db86f54
Merge branch 'main' into get-unit-spike-trains
alejoe91 Apr 14, 2026
1ea1d83
Add tests for `sorting_tools.is_spike_vector_sorted()`
grahamfindlay Apr 17, 2026
98fa004
Leverage single-segment nature of BasePhyKilosortSortingExtractor to …
grahamfindlay Apr 17, 2026
72d7395
Make `is_spike_vector_sorted()` chunked, add early stopping, add `ass…
grahamfindlay Apr 17, 2026
79285fb
Leverage single-segment shortcuts when possible in `UnitSelectionSort…
grahamfindlay Apr 17, 2026
c495e3a
Optimize Phy/Kilosort `_compute_and_cache_spike_vector()`.
grahamfindlay May 20, 2026
9a72870
Further optimize Phy/Kilosort `get_unit_spike_trains()`
grahamfindlay May 20, 2026
74d871a
Phy/Kilosort skips bad cluster removal if there are no bad clusters.
grahamfindlay May 20, 2026
8d288de
Optimize UnitSelectionSorting.to_spike_vector()
grahamfindlay May 21, 2026
9d4be47
Shortcut handling of "identity selection" in UnitSelectionSorting.to_…
grahamfindlay May 21, 2026
375fb23
Optimize `to_reordered_spike_vector`
grahamfindlay May 21, 2026
7f8b5e8
Merge pull request #29 from grahamfindlay/dev-pr4502
alejoe91 Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 241 additions & 43 deletions src/spikeinterface/core/basesorting.py

Large diffs are not rendered by default.

442 changes: 439 additions & 3 deletions src/spikeinterface/core/sorting_tools.py

Large diffs are not rendered by default.

150 changes: 150 additions & 0 deletions src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,131 @@ def test_BaseSorting(create_cache_folder):
assert sorting.get_annotation(annotation_name) == sorting_zarr_loaded.get_annotation(annotation_name)


def _reference_reordered_spike_vector(spikes, lexsort, num_units, num_segments):
    """Pre-optimization reference: np.lexsort + nested searchsorted.

    Mirrors the implementation that lived in `to_reordered_spike_vector`
    before the counting-sort rewrite. Used to assert byte-for-byte parity
    of the new implementation.

    Parameters
    ----------
    spikes : np.ndarray
        Structured array with at least the fields "sample_index",
        "unit_index" and "segment_index".
    lexsort : sequence of str
        Sort keys, least significant first (np.lexsort convention). Only
        ("sample_index", "segment_index", "unit_index") and
        ("sample_index", "unit_index", "segment_index") are supported;
        a list with the same content is accepted as well.
    num_units : int
        Number of units; unit indices are expected in [0, num_units).
    num_segments : int
        Number of segments; segment indices are expected in [0, num_segments).

    Returns
    -------
    ordered_spikes : np.ndarray
        `spikes[order]`.
    order : np.ndarray
        Permutation returned by np.lexsort.
    slices : np.ndarray
        int64 array of [start, stop) bounds into `ordered_spikes`. Shape is
        (num_units, num_segments, 2) when "unit_index" is the most
        significant key, (num_segments, num_units, 2) otherwise.

    Raises
    ------
    ValueError
        If `lexsort` is not one of the two supported key orders.
    """
    # Normalise and validate before doing any work, so that a malformed key
    # always fails with ValueError instead of an IndexError / field lookup
    # error raised halfway through the sort.
    lexsort = tuple(lexsort)
    supported = (
        ("sample_index", "segment_index", "unit_index"),
        ("sample_index", "unit_index", "segment_index"),
    )
    if lexsort not in supported:
        raise ValueError(f"Unsupported lexsort {lexsort!r}; expected one of {supported!r}")

    order = np.lexsort(tuple(spikes[key] for key in lexsort))
    ordered_spikes = spikes[order]

    # np.lexsort treats the LAST key as the primary one, so the outer grouping
    # is lexsort[2] and the inner grouping is lexsort[1].
    inner_key, outer_key = lexsort[1], lexsort[2]
    if outer_key == "unit_index":
        num_outer, num_inner = num_units, num_segments
    else:
        num_outer, num_inner = num_segments, num_units

    slices = np.zeros((num_outer, num_inner, 2), dtype=np.int64)
    inner_edges = np.arange(num_inner + 1)
    outer_bounds = np.searchsorted(ordered_spikes[outer_key], np.arange(num_outer + 1), side="left")
    for outer_index in range(num_outer):
        start = outer_bounds[outer_index]
        stop = outer_bounds[outer_index + 1]
        # Bounds of every inner group, relative to the start of this outer group.
        inner_bounds = np.searchsorted(ordered_spikes[start:stop][inner_key], inner_edges, side="left")
        slices[outer_index, :, 0] = start + inner_bounds[:-1]
        slices[outer_index, :, 1] = start + inner_bounds[1:]

    return ordered_spikes, order, slices


def test_to_reordered_spike_vector_parity():
    """Compare `to_reordered_spike_vector` with the np.lexsort reference helper.

    For both supported key orders the reordered spikes and the slice table
    must equal the reference output, `spikes[order]` must reproduce the
    reordered spikes, and every slice must contain a single (unit, segment)
    group with non-decreasing sample indices.
    """
    rng = np.random.default_rng(42)
    num_units = 6
    num_segments = 3
    sampling_frequency = 30_000.0
    shared_samples = np.array([100, 200, 300, 400, 500])

    # One dict of spike trains per segment. Every unit except unit 0 gets the
    # same five sample indices within a segment, so several units fire at an
    # identical sample_index and tie-breaking on unit_index is exercised.
    spike_dicts = []
    for seg in range(num_segments):
        trains = {}
        for u in range(num_units):
            n = int(rng.integers(50, 200))
            times = np.sort(rng.integers(0, 10_000, size=n))
            if u > 0 and n > 5:
                times[:5] = shared_samples + seg * 10
                times = np.sort(times)
            trains[str(u)] = times.astype("int64")
        spike_dicts.append(trains)

    sorting = NumpySorting.from_unit_dict(spike_dicts, sampling_frequency)
    spikes = sorting.to_spike_vector()

    def _assert_single_group(block, unit_index, segment_index):
        # A slice must hold exactly one (unit, segment) group, sorted by sample.
        assert np.all(block["unit_index"] == unit_index)
        assert np.all(block["segment_index"] == segment_index)
        assert np.all(np.diff(block["sample_index"]) >= 0)

    unit_major = ("sample_index", "segment_index", "unit_index")
    segment_major = ("sample_index", "unit_index", "segment_index")

    for lexsort in [unit_major, segment_major]:
        # Reset the cache so that each key order triggers a fresh build.
        sorting._cached_lexsorted_spike_vector = {}

        ordered_spikes, order, slices = sorting.to_reordered_spike_vector(
            lexsort=lexsort, return_order=True, return_slices=True
        )
        ref_ordered, ref_order, ref_slices = _reference_reordered_spike_vector(
            spikes, lexsort, num_units, num_segments
        )

        # The reordered records themselves must match the reference exactly.
        assert np.array_equal(ordered_spikes, ref_ordered), f"ordered mismatch for {lexsort}"
        # `order` is deliberately not compared with `ref_order`: source rows that
        # carry the same (sample, unit, segment) triple may be permuted
        # differently while still producing identical output. Only the
        # invariant `spikes[order] == ordered_spikes` is required.
        assert np.array_equal(spikes[order], ordered_spikes)
        assert np.array_equal(slices, ref_slices), f"slices mismatch for {lexsort}"

        if lexsort == unit_major:
            # slices is indexed [unit, segment]
            for u in range(num_units):
                for s in range(num_segments):
                    start, stop = slices[u, s]
                    _assert_single_group(ordered_spikes[start:stop], u, s)
        else:
            # slices is indexed [segment, unit]
            for s in range(num_segments):
                for u in range(num_units):
                    start, stop = slices[s, u]
                    _assert_single_group(ordered_spikes[start:stop], u, s)


def test_to_reordered_spike_vector_empty():
    """A sorting with one unit and no spikes must reorder to empty outputs.

    The reordered vector and the permutation are empty, and the slice table
    has shape (1, 1, 2) filled with zeros.
    """
    no_spikes = np.array([], dtype="int64")
    sorting = NumpySorting.from_unit_dict({"0": no_spikes}, 30_000.0)

    result = sorting.to_reordered_spike_vector(
        lexsort=("sample_index", "segment_index", "unit_index"),
        return_order=True,
        return_slices=True,
    )
    ordered_spikes, order, slices = result

    assert ordered_spikes.size == 0
    assert order.size == 0

    expected_slices = np.zeros((1, 1, 2), dtype=np.int64)
    assert slices.shape == (1, 1, 2)
    assert np.array_equal(slices, expected_slices)


def test_npy_sorting():
sfreq = 10
spike_times_0 = {
Expand Down Expand Up @@ -310,6 +435,31 @@ def test_select_periods():
np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector())


@pytest.mark.parametrize("use_cache", [False, True])
def test_get_unit_spike_trains(use_cache):
    """Bulk spike-train getters must agree with the per-unit getters.

    Checked for both sample-index and in-seconds variants on a generated
    single-segment sorting, with and without the cache.
    """
    sorting = generate_sorting(durations=[1.0], sampling_frequency=10_000.0, num_units=10)
    unit_ids = sorting.unit_ids

    # Spike trains expressed as sample indices.
    trains_by_unit = sorting.get_unit_spike_trains(unit_ids=unit_ids, use_cache=use_cache)
    assert isinstance(trains_by_unit, dict)
    assert set(trains_by_unit.keys()) == set(unit_ids)
    for unit_id in unit_ids:
        expected = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache)
        assert np.array_equal(trains_by_unit[unit_id], expected)

    # Spike trains expressed in seconds.
    times_by_unit = sorting.get_unit_spike_trains_in_seconds(unit_ids=unit_ids, use_cache=use_cache)
    assert isinstance(times_by_unit, dict)
    assert set(times_by_unit.keys()) == set(unit_ids)
    for unit_id in unit_ids:
        expected_times = sorting.get_unit_spike_train_in_seconds(
            segment_index=0, unit_id=unit_id, use_cache=use_cache
        )
        assert np.allclose(times_by_unit[unit_id], expected_times)


if __name__ == "__main__":
import tempfile

Expand Down
Loading
Loading