From df685e2f46c4ff75a1ff287e0eec413a5d86c4da Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 15 Jun 2026 17:34:43 +0200 Subject: [PATCH 1/5] wip: lazy load --- spikeinterface_gui/controller.py | 39 +++++++++---------- spikeinterface_gui/main.py | 6 ++- spikeinterface_gui/spikelistview.py | 12 ++++-- .../tests/test_mainwindow_panel.py | 3 +- .../tests/test_mainwindow_qt.py | 14 +++---- spikeinterface_gui/traceview.py | 3 +- 6 files changed, 41 insertions(+), 36 deletions(-) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..98d3cb5 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -52,6 +52,7 @@ def __init__( curation_callback=None, curation_callback_kwargs=None, user_main_settings=None, + lazy_load=False ): self.views = [] skip_extensions = skip_extensions if skip_extensions is not None else [] @@ -100,7 +101,7 @@ def __init__( # Mandatory extensions: computation forced if verbose: print('\tLoading templates') - temp_ext = self.analyzer.get_extension("templates") + temp_ext = self.analyzer.get_extension("templates", lazy=lazy_load) if temp_ext is None: temp_ext = self.analyzer.compute_one_extension("templates") self.nbefore, self.nafter = temp_ext.nbefore, temp_ext.nafter @@ -150,9 +151,9 @@ def __init__( else: if verbose: print('\tLoading spike_amplitudes') - sa_ext = analyzer.get_extension('spike_amplitudes') + sa_ext = analyzer.get_extension('spike_amplitudes', lazy=lazy_load) if sa_ext is not None: - self.spike_amplitudes = sa_ext.get_data() + self.spike_amplitudes = sa_ext.get_data(copy=False) else: self.spike_amplitudes = None @@ -163,9 +164,9 @@ def __init__( else: if verbose: print('\tLoading amplitude_scalings') - sa_ext = analyzer.get_extension('amplitude_scalings') + sa_ext = analyzer.get_extension('amplitude_scalings', lazy=lazy_load) if sa_ext is not None: - self.amplitude_scalings = sa_ext.get_data() + self.amplitude_scalings = sa_ext.get_data(copy=False) else: self.amplitude_scalings = None @@ -176,9 +177,9 @@ def __init__( else: if verbose: print('\tLoading spike_locations') - sl_ext = analyzer.get_extension('spike_locations') + sl_ext = analyzer.get_extension('spike_locations', lazy=lazy_load) if sl_ext is not None: - self.spike_depths = sl_ext.get_data()["y"] + self.spike_depths = sl_ext.get_data(copy=False)["y"] else: self.spike_depths = None @@ -190,7 +191,7 @@ def __init__( else: if verbose: print('\tLoading correlograms') - ccg_ext = analyzer.get_extension('correlograms') + ccg_ext = analyzer.get_extension('correlograms', lazy=lazy_load) if ccg_ext is not None: self.correlograms, self.correlograms_bins = ccg_ext.get_data() else: @@ -235,7 +236,7 @@ def __init__( else: if verbose: print('\tLoading waveforms') - wf_ext = analyzer.get_extension('waveforms') + wf_ext = analyzer.get_extension('waveforms', lazy=lazy_load) if wf_ext is not None: self.waveforms_ext = wf_ext else: @@ -248,7 +249,7 @@ def __init__( else: if verbose: print('\tLoading principal_components') - pc_ext = analyzer.get_extension('principal_components') + pc_ext = analyzer.get_extension('principal_components', lazy=lazy_load) self.pc_ext = pc_ext if analyzer.has_extension("valid_unit_periods"): @@ -295,22 +296,18 @@ def __init__( unit_ids = self.analyzer.unit_ids num_seg = self.analyzer.get_num_segments() self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") - # print("self.num_spikes", self.num_spikes) - spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel) - # spike_vector = self.analyzer.sorting.to_spike_vector(concatenated=True) + self.spikes = self.analyzer.sorting.to_spike_vector() + print(f"spike vector: {type(self.spikes)}") self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data() - self.spikes = np.zeros(spike_vector.size, dtype=spike_dtype) - self.spikes['sample_index'] = spike_vector['sample_index'] - self.spikes['unit_index'] = spike_vector['unit_index'] - self.spikes['segment_index'] = spike_vector['segment_index'] - self.spikes['channel_index'] = spike_vector['channel_index'] - self.spikes['rand_selected'][:] = False - self.spikes['rand_selected'][self.random_spikes_indices] = True + ext_channel_inds = np.array([self._extremum_channel[unit_id] for unit_id in self.unit_ids]) + self.spike_channel_index = ext_channel_inds[self.spikes["unit_index"]] + self.spike_rand_selected = np.zeros(len(self.spikes), dtype=bool) + self.spike_rand_selected[self.random_spikes_indices] = True - # self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") + # TODO: minimize memory here seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 8f28c24..90b6f35 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -41,6 +41,7 @@ def run_mainwindow( verbose: bool = False, user_settings: dict | None = None, disable_save_settings_button: bool = False, + lazy_load: bool = False ): """ Create the main window and start the QT app loop. @@ -109,6 +110,8 @@ def run_mainwindow( A dictionary of user settings for each view, which overwrite the default settings. disable_save_settings_button: bool, default: False If True, disables the "save default settings" button, so that user cannot do this. + lazy_load: bool, default : False + If True, arrays are lazy loaded to use less RAM """ if mode == "desktop": @@ -165,7 +168,8 @@ def run_mainwindow( external_data=external_data, curation_callback=curation_callback, curation_callback_kwargs=curation_callback_kwargs, - user_main_settings=user_main_settings + user_main_settings=user_main_settings, + lazy_load=lazy_load ) if verbose: t1 = time.perf_counter() diff --git a/spikeinterface_gui/spikelistview.py b/spikeinterface_gui/spikelistview.py index c8cd73c..faca435 100644 --- a/spikeinterface_gui/spikelistview.py +++ b/spikeinterface_gui/spikelistview.py @@ -53,6 +53,8 @@ def data(self, index, role): abs_ind = self.visible_ind[row] spike = self.controller.spikes[abs_ind] + channel_index = self.controller.spike_channel_index[abs_ind] + rand_selected = self.controller.spike_rand_selected[abs_ind] unit_id = self.controller.unit_ids[spike['unit_index']] if role ==QT.Qt.DisplayRole : @@ -65,9 +67,9 @@ def data(self, index, role): elif col == 3: return '{}'.format(spike['sample_index']) elif col == 4: - return '{}'.format(spike['channel_index']) + return '{}'.format(channel_index) elif col == 5: - return '{}'.format(spike['rand_selected']) + return '{}'.format(rand_selected) else: return None elif role == QT.Qt.DecorationRole : @@ -309,6 +311,8 @@ def _panel_refresh_table(self): visible_inds = self.controller.get_indices_spike_visible() unit_ids = self.controller.unit_ids spikes = self.controller.spikes[visible_inds] + channel_inds = self.controller.spike_channel_index + rand_selected = self.controller.spike_rand_selected spike_unit_ids = [] for i, spike in enumerate(spikes): @@ -322,8 +326,8 @@ def _panel_refresh_table(self): 'unit_id': spike_unit_ids, 'segment_index': spikes['segment_index'], 'sample_index': spikes['sample_index'], - 'channel_index': spikes['channel_index'], - 'rand_selected': spikes['rand_selected'] + 'channel_index': channel_inds, + 'rand_selected': rand_selected } # Update table data without replacing entire dataframe diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 3971078..66f7dab 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -33,8 +33,7 @@ def teardown_module(): def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False, port=0): - - analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer", load_extensions=False) # analyzer = load_analyzer(test_folder / "sorting_analyzer.zarr") print(analyzer) diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 3349eba..2d42554 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -35,15 +35,12 @@ def teardown_module(): clean_all(test_folder) -def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False): +def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_extensions=False, events=False, lazy_load=False): - analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer", load_extensions=False, lazy=lazy_load) # analyzer = load_analyzer(test_folder / "sorting_analyzer.zarr") - tm = analyzer.get_extension("template_metrics").get_data().iloc[0, :] - # print(tm) - # return print(analyzer) @@ -109,7 +106,8 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext displayed_unit_properties=None, extra_unit_properties=extra_unit_properties, layout_preset='default', - events=events_dict + events=events_dict, + lazy_load=lazy_load # user_settings={"mainsettings": {"color_mode": "color_by_visibility", "max_visible_units": 5}} ) @@ -144,6 +142,8 @@ def test_launcher(verbose=True): parser = ArgumentParser() parser.add_argument('--dataset', default="small", help='Path to the dataset folder') parser.add_argument('--events', action="store_true", help='Simulate and add events') +parser.add_argument('--lazy', action="store_true", help='Lazy load') + if __name__ == '__main__': args = parser.parse_args() @@ -155,7 +155,7 @@ def test_launcher(verbose=True): if not test_folder.is_dir(): setup_module() - win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events) + win = test_mainwindow(start_app=True, verbose=True, curation=True, events=args.events, lazy_load=args.lazy) # win = test_mainwindow(start_app=True, verbose=True, curation=False) # test_launcher(verbose=True) diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index e79f29d..032b21a 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -38,6 +38,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): spikes_seg = self.controller.spikes[sl] i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2]) spikes_chunk = spikes_seg[i1:i2].copy() + spikes_channel_chunk = self.controller.spike_channel_index[sl] spikes_chunk["sample_index"] -= ind1 # for trace map view, this returns the channels ordered by depth @@ -73,7 +74,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): # Get spikes for this unit unit_spikes = spikes_chunk[inds] - channel_inds = unit_spikes["channel_index"] + channel_inds = spikes_channel_chunk[inds] sample_inds = unit_spikes["sample_index"] chan_mask = np.isin(channel_inds, visible_channel_inds) From efb3eae8b24d4491d4807a1c411a1d0b285ed84c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:09:08 +0200 Subject: [PATCH 2/5] test: option to test analyzer in lazy mode --- spikeinterface_gui/controller.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 98d3cb5..8064e2a 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -52,7 +52,6 @@ def __init__( curation_callback=None, curation_callback_kwargs=None, user_main_settings=None, - lazy_load=False ): self.views = [] skip_extensions = skip_extensions if skip_extensions is not None else [] @@ -101,7 +100,7 @@ def __init__( # Mandatory extensions: computation forced if verbose: print('\tLoading templates') - temp_ext = self.analyzer.get_extension("templates", lazy=lazy_load) + temp_ext = self.analyzer.get_extension("templates") if temp_ext is None: temp_ext = self.analyzer.compute_one_extension("templates") self.nbefore, self.nafter = temp_ext.nbefore, temp_ext.nafter @@ -151,7 +150,7 @@ def __init__( else: if verbose: print('\tLoading spike_amplitudes') - sa_ext = analyzer.get_extension('spike_amplitudes', lazy=lazy_load) + sa_ext = analyzer.get_extension('spike_amplitudes') if sa_ext is not None: self.spike_amplitudes = sa_ext.get_data(copy=False) else: @@ -164,7 +163,7 @@ def __init__( else: if verbose: print('\tLoading amplitude_scalings') - sa_ext = analyzer.get_extension('amplitude_scalings', lazy=lazy_load) + sa_ext = analyzer.get_extension('amplitude_scalings') if sa_ext is not None: self.amplitude_scalings = sa_ext.get_data(copy=False) else: @@ -177,7 +176,7 @@ def __init__( else: if verbose: print('\tLoading spike_locations') - sl_ext = analyzer.get_extension('spike_locations', lazy=lazy_load) + sl_ext = analyzer.get_extension('spike_locations') if sl_ext is not None: self.spike_depths = sl_ext.get_data(copy=False)["y"] else: @@ -191,7 +190,7 @@ def __init__( else: if verbose: print('\tLoading correlograms') - ccg_ext = analyzer.get_extension('correlograms', lazy=lazy_load) + ccg_ext = analyzer.get_extension('correlograms') if ccg_ext is not None: self.correlograms, self.correlograms_bins = ccg_ext.get_data() else: @@ -236,7 +235,7 @@ def __init__( else: if verbose: print('\tLoading waveforms') - wf_ext = analyzer.get_extension('waveforms', lazy=lazy_load) + wf_ext = analyzer.get_extension('waveforms') if wf_ext is not None: self.waveforms_ext = wf_ext else: @@ -249,7 +248,7 @@ def __init__( else: if verbose: print('\tLoading principal_components') - pc_ext = analyzer.get_extension('principal_components', lazy=lazy_load) + pc_ext = analyzer.get_extension('principal_components') self.pc_ext = pc_ext if analyzer.has_extension("valid_unit_periods"): From 1b23ee367406945712600240f83806ae0ed93c05 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:11:46 +0200 Subject: [PATCH 3/5] fixes --- spikeinterface_gui/controller.py | 1 - spikeinterface_gui/main.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 8064e2a..ab81034 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -297,7 +297,6 @@ def __init__( self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") self.spikes = self.analyzer.sorting.to_spike_vector() - print(f"spike vector: {type(self.spikes)}") self.random_spikes_indices = self.analyzer.get_extension("random_spikes").get_data() diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 90b6f35..b2b647f 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -41,7 +41,6 @@ def run_mainwindow( verbose: bool = False, user_settings: dict | None = None, disable_save_settings_button: bool = False, - lazy_load: bool = False ): """ Create the main window and start the QT app loop. @@ -110,8 +109,6 @@ def run_mainwindow( A dictionary of user settings for each view, which overwrite the default settings. disable_save_settings_button: bool, default: False If True, disables the "save default settings" button, so that user cannot do this. - lazy_load: bool, default : False - If True, arrays are lazy loaded to use less RAM """ if mode == "desktop": @@ -169,7 +166,6 @@ def run_mainwindow( curation_callback=curation_callback, curation_callback_kwargs=curation_callback_kwargs, user_main_settings=user_main_settings, - lazy_load=lazy_load ) if verbose: t1 = time.perf_counter() From e9244b5b18c9f5e7650e14eeb48120a81e8a61a0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:12:09 +0200 Subject: [PATCH 4/5] remove trailing comma --- spikeinterface_gui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index b2b647f..8f28c24 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -165,7 +165,7 @@ def run_mainwindow( external_data=external_data, curation_callback=curation_callback, curation_callback_kwargs=curation_callback_kwargs, - user_main_settings=user_main_settings, + user_main_settings=user_main_settings ) if verbose: t1 = time.perf_counter() From bea4f4588c4edd03033962912999d618e3dfcdda Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:21:55 +0200 Subject: [PATCH 5/5] perf: pre-load seg_limits if available --- spikeinterface_gui/controller.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index ab81034..bdaecfc 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -305,10 +305,13 @@ def __init__( self.spike_rand_selected = np.zeros(len(self.spikes), dtype=bool) self.spike_rand_selected[self.random_spikes_indices] = True - # TODO: minimize memory here - seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) + if self.analyzer.sorting._cached_spike_vector_segment_slices is not None: + seg_limits = self.analyzer.sorting._cached_spike_vector_segment_slices + else: + seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} - + + # TODO: minimize memory here spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False) self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2] # this is dict of list because per segment spike_indices[segment_index][unit_id]