"""Stimulation-aware spike sorting pipeline.
Applies pre-trained RT-Sort sequences (from the Phase 1 vanilla sort)
to a stimulation recording and returns per-event peri-stim
``SpikeSliceStack``.
The pipeline is **per-event-chunked** by default: only the peri-event
time window around each stim event (with buffers for recentering and
artifact removal) is ever materialised. Peak RAM scales with a
single chunk's working set (typically ~100-200 MB on a 1018-ch MaxOne
recording) rather than with total recording duration. The full
chunked path is used whenever the caller passes a path or a lazy
SpikeInterface recording. Pre-materialised ``np.ndarray`` input
falls back to the legacy full-recording path (caller controls memory).
Per chunk, the pipeline:
1. Read the chunk's filtered traces (from the top-level lazy
recording) and the pre-filter traces (walking up to the first
non-filter parent), DC-centering the latter per channel.
2. Recenter the stim event time(s) within the chunk using the
configured ``peak_mode`` (``"down_edge"`` for biphasic anodic-
first pulses, etc.).
3. Remove artifacts (auto 2- or 3-way polynomial split at the
negative peak / subsequent positive peak, or single fit).
4. Run ``RTSort.sort_offline`` on the cleaned chunk.
5. Extract the peri-event ``[-pre_ms, +post_ms]`` slice per event.
6. Drop the chunk; accumulate the per-event ``SpikeData`` slices.
Events whose chunks would overlap (e.g., burst / paired-pulse
protocols) are grouped into a single chunk so sort_offline sees them
together.
"""
from pathlib import Path
import numpy as np
# Extra margin beyond the peri-event + recentering + artifact-removal
# window. Gives RT-Sort's detection model a few ms to warm up before
# the peri-event region starts, so the first few samples after the
# sort_offline reset are not in the output window. The algorithm's
# internal buffer_size is 100 samples (~5 ms at 20 kHz), so 30 ms is
# a comfortable warmup — small enough that chunks for stim at 2 Hz
# (500 ms apart) do not merge into one big chunk.
_CHUNK_WARMUP_MS = 30.0  # ms, added to both the pre- and post-side of every chunk window
def sort_stim_recording(
    stim_recording,
    rt_sort,
    stim_times_ms,
    pre_ms,
    post_ms,
    fs_Hz=None,
    *,
    artifact_method="polynomial",
    artifact_window_ms=10.0,
    saturation_threshold=None,
    baseline_threshold=None,
    poly_order=3,
    artifact_window_only=True,
    max_stim_offset_ms=50.0,
    peak_mode="abs_max",
    n_reference_channels=8,
    prewindow_ms=5.0,
    multi_peak=False,
    multi_peak_select="first",
    multi_peak_threshold=0.6,
    multi_peak_min_separation_ms=2.0,
    model=None,
    model_path=None,
    recording_window_ms=None,
    verbose=True,
):
    """Sort spikes in a stimulation recording with pre-trained RT-Sort sequences.

    Given a raw stimulation recording and a trained ``RTSort`` object (or
    a path to one saved by ``sort_recording(..., sorter="rt_sort")``),
    this recenters the logged stim times on the recorded artifact, strips
    stimulation artifacts, runs offline spike sorting, and returns a
    ``SpikeSliceStack`` of sorted spikes aligned to the corrected stim
    event times.

    Memory model: path or lazy SpikeInterface-recording input is processed
    one peri-event time chunk at a time (peak RAM ≈ one chunk's working
    set, independent of recording duration); ``np.ndarray`` input uses
    the legacy full-recording path, since the caller has already paid the
    memory cost.

    Parameters:
        stim_recording: ``str``/``Path`` to a recording file (Maxwell .h5
            or NWB) or a SpikeInterface ``BaseRecording`` — both use the
            chunked path — or an ``np.ndarray`` of shape
            ``(channels, samples)``, which uses the full-recording path.
        rt_sort: Trained RT-Sort object or path to its pickle.
        stim_times_ms (array-like): Logged stim event times in ms.
        pre_ms (float): Output peri-event window radius before each
            event, in ms.
        post_ms (float): Output peri-event window radius after each
            event, in ms.
        fs_Hz (float or None): Sampling frequency in Hz. Required for
            ndarray input; inferred from the recording otherwise.
        artifact_method (str): ``"polynomial"`` (default) or ``"blank"``;
            forwarded to ``remove_stim_artifacts``.
        artifact_window_ms (float): Max artifact tail duration after the
            last desaturation. Default 10.0.
        saturation_threshold (float or None): Saturation voltage
            threshold; ``None`` auto-detects.
        baseline_threshold (float or None): Baseline envelope threshold;
            ``None`` auto-detects from pre-stim MAD.
        poly_order (int): Polynomial order for detrend. Default 3.
        artifact_window_only (bool): Only process around stim events.
            Default True.
        max_stim_offset_ms (float): Search window radius for stim time
            recentering. Default 50.0.
        peak_mode (str): Alignment target for ``recenter_stim_times`` —
            ``"abs_max"`` (default), ``"pos_peak"``, ``"neg_peak"``,
            ``"down_edge"``, ``"up_edge"``. Use ``"down_edge"`` for
            biphasic anodic-first pulses.
        n_reference_channels (int): Top-K highest-amplitude channels
            summed into the signed reference trace for non-``abs_max``
            modes. Default 8.
        prewindow_ms (float): Pre-window radius before the primary peak
            for the edge modes. Default 5.0.
        multi_peak (bool): Treat the search window as a possible stim
            train and align to the first/last qualifying pulse instead
            of the strongest. Default False (identical to the
            pre-multi-peak behaviour).
        multi_peak_select (str): Which qualifying peak to lock onto when
            ``multi_peak=True`` — ``"first"`` (default) or ``"last"``.
        multi_peak_threshold (float): Ignore peaks below this fraction of
            the window's largest peak. Default 0.6.
        multi_peak_min_separation_ms (float): Minimum spacing between
            candidate peaks. Default 2.0.
        model (ModelSpikeSorter or None): Detection model instance for
            ``load_rt_sort`` when ``rt_sort`` is a path.
        model_path (str or Path or None): Detection model folder for
            ``load_rt_sort`` when ``rt_sort`` is a path.
        recording_window_ms (tuple or None): ``(start_ms, end_ms)``
            sub-window to restrict processing to; ``None`` processes the
            full recording.
        verbose (bool): Print progress messages. Default True.

    Returns:
        stim_slices (SpikeSliceStack): One slice per (corrected) stim
        event, each spanning ``[-pre_ms, +post_ms]`` around the event.
    """
    from ..rt_sort_runner import load_rt_sort  # noqa: F401 (validates install)

    stim_times_ms = np.asarray(stim_times_ms, dtype=np.float64)

    # Load the RTSort exactly once, whatever form the caller handed us.
    rt_sort_obj = _load_rt_sort(rt_sort, model, model_path, verbose)

    # Keyword arguments shared verbatim by both backends.
    shared_kwargs = dict(
        stim_times_ms=stim_times_ms,
        pre_ms=pre_ms,
        post_ms=post_ms,
        artifact_method=artifact_method,
        artifact_window_ms=artifact_window_ms,
        saturation_threshold=saturation_threshold,
        baseline_threshold=baseline_threshold,
        poly_order=poly_order,
        artifact_window_only=artifact_window_only,
        max_stim_offset_ms=max_stim_offset_ms,
        peak_mode=peak_mode,
        n_reference_channels=n_reference_channels,
        prewindow_ms=prewindow_ms,
        multi_peak=multi_peak,
        multi_peak_select=multi_peak_select,
        multi_peak_threshold=multi_peak_threshold,
        multi_peak_min_separation_ms=multi_peak_min_separation_ms,
        recording_window_ms=recording_window_ms,
        verbose=verbose,
    )

    # Pre-materialised arrays cannot be chunked → legacy full path.
    if isinstance(stim_recording, np.ndarray):
        return _sort_stim_full_recording(
            traces=stim_recording,
            recording_obj=None,
            fs_Hz=fs_Hz,
            rt_sort_obj=rt_sort_obj,
            **shared_kwargs,
        )

    # Path or BaseRecording → memory-bounded chunked path.
    if isinstance(stim_recording, (str, Path)):
        if verbose:
            print(f"Opening recording {stim_recording} (lazy, for chunked reads)...")
        from ..recording_io import load_single_recording

        rec = load_single_recording(stim_recording)
    else:
        rec = stim_recording
    return _sort_stim_chunked(recording=rec, rt_sort_obj=rt_sort_obj, **shared_kwargs)
# ---------------------------------------------------------------------------
# Chunked path — the default when we have a lazy recording
# ---------------------------------------------------------------------------
def _sort_stim_chunked(
    recording,
    rt_sort_obj,
    stim_times_ms,
    pre_ms,
    post_ms,
    *,
    artifact_method,
    artifact_window_ms,
    saturation_threshold,
    baseline_threshold,
    poly_order,
    artifact_window_only,
    max_stim_offset_ms,
    peak_mode,
    n_reference_channels,
    prewindow_ms,
    multi_peak,
    multi_peak_select,
    multi_peak_threshold,
    multi_peak_min_separation_ms,
    recording_window_ms,
    verbose,
):
    """Chunked backend of ``sort_stim_recording`` for lazy recordings.

    For each group of stim events whose chunk windows overlap, this
    reads only that chunk's filtered traces (plus the pre-filter
    parent's traces, DC-centered per channel), recenters the event
    times inside the chunk, removes artifacts in place, runs
    ``sort_offline`` on the cleaned chunk, slices out the
    ``[-pre_ms, +post_ms]`` window per event, and frees the chunk
    before moving on.  Parameters mirror ``sort_stim_recording`` — see
    its docstring for their semantics.

    Returns:
        SpikeSliceStack: one slice per kept event, ordered like the
        caller's ``stim_times_ms`` (after bounds filtering), with
        ``times_start_to_end`` in absolute recording milliseconds.

    Raises:
        ValueError: if bounds filtering leaves no stim events.
        RuntimeError: if a kept event did not receive a slice
            (internal invariant violation).
    """
    from ...spikedata.spikedata import SpikeData  # noqa: F401
    from ...spikedata.spikeslicestack import SpikeSliceStack
    from .artifact_removal import remove_stim_artifacts
    from .recentering import recenter_stim_times

    fs_Hz = float(recording.get_sampling_frequency())
    n_total_samples = int(recording.get_num_samples())
    # First non-filter ancestor of the preprocessing chain (or None);
    # its traces are read per chunk and handed to artifact removal.
    raw_parent = _find_prefilter_parent(recording)
    # Chunk window budget — each chunk must encompass:
    # pre_ms + max_stim_offset_ms + _CHUNK_WARMUP_MS (before each event)
    # post_ms + artifact_window_ms + _CHUNK_WARMUP_MS (after each event)
    # Note: ``max_stim_offset_ms`` is the search radius for *recentering*,
    # which only applies before the logged stim time (we look in the
    # recording to find where the artifact actually is). It is NOT
    # needed in the post-window.
    chunk_pre_ms = pre_ms + max_stim_offset_ms + _CHUNK_WARMUP_MS
    chunk_post_ms = post_ms + artifact_window_ms + _CHUNK_WARMUP_MS
    # Filter events by the recording_window_ms + actual recording bounds.
    if recording_window_ms is not None:
        rwin_lo, rwin_hi = recording_window_ms
    else:
        rwin_lo, rwin_hi = 0.0, n_total_samples / fs_Hz * 1000.0
    rec_lo_ms = 0.0
    rec_hi_ms = n_total_samples / fs_Hz * 1000.0
    # Keep events whose output window fits the requested sub-window AND
    # whose full chunk window fits the recording (1e-6 ms float slack).
    valid_mask = (
        (stim_times_ms - pre_ms >= rwin_lo)
        & (stim_times_ms + post_ms <= rwin_hi)
        & (stim_times_ms - chunk_pre_ms >= rec_lo_ms - 1e-6)
        & (stim_times_ms + chunk_post_ms <= rec_hi_ms + 1e-6)
    )
    n_dropped = int(np.sum(~valid_mask))
    if n_dropped > 0 and verbose:
        print(
            f" Dropping {n_dropped} event(s) whose peri-event window would "
            f"extend outside the recording / recording_window_ms bounds"
        )
    global_event_indices = np.flatnonzero(valid_mask)
    if len(global_event_indices) == 0:
        raise ValueError("No stim events left after filtering for recording bounds.")
    kept_times_ms = stim_times_ms[global_event_indices]
    # Group adjacent events into chunks.
    groups = _group_stim_events_into_chunks(kept_times_ms, chunk_pre_ms, chunk_post_ms)
    if verbose:
        print(
            f"Chunking {len(kept_times_ms)} stim events into {len(groups)} "
            f"time chunk(s) (chunk window ≈ "
            f"{chunk_pre_ms + chunk_post_ms:.0f} ms per event)"
        )
    # Process chunks in order; accumulate per-event results keyed by
    # original stim_times_ms index so we can emit the final slice stack
    # in the caller's event order.
    per_event_sd = [None] * len(kept_times_ms)
    per_event_global_corr_ms = np.empty(len(kept_times_ms), dtype=np.float64)
    for chunk_idx, group in enumerate(groups):
        group_kept = np.asarray(group, dtype=int)  # indices into kept_times_ms
        group_events_ms = kept_times_ms[group_kept]
        # Chunk span covers the first event's pre-window through the
        # last event's post-window, clamped to the recording.
        chunk_lo_ms = float(group_events_ms[0] - chunk_pre_ms)
        chunk_hi_ms = float(group_events_ms[-1] + chunk_post_ms)
        start_frame = max(0, int(np.floor(chunk_lo_ms * fs_Hz / 1000.0)))
        end_frame = min(n_total_samples, int(np.ceil(chunk_hi_ms * fs_Hz / 1000.0)))
        chunk_start_ms = start_frame * 1000.0 / fs_Hz
        chunk_len_ms = (end_frame - start_frame) * 1000.0 / fs_Hz
        if verbose:
            print(
                f" chunk {chunk_idx + 1}/{len(groups)}: "
                f"{len(group_events_ms)} event(s), "
                f"{chunk_len_ms:.0f} ms ({end_frame - start_frame} samples)"
            )
        # Load filtered chunk traces.
        chunk_traces = recording.get_traces(
            start_frame=start_frame, end_frame=end_frame, return_scaled=True
        ).T.astype(np.float32, copy=False)
        # Load matching pre-filter raw chunk and DC-center per channel
        # (in-place to avoid a transient duplicate allocation).
        chunk_raw = None
        if raw_parent is not None:
            chunk_raw = raw_parent.get_traces(
                start_frame=start_frame, end_frame=end_frame, return_scaled=True
            ).T.astype(np.float32, copy=False)
            chunk_raw -= np.median(chunk_raw, axis=1, keepdims=True)
        # Recenter stim times within this chunk (event times in local
        # chunk coordinates, ms from the chunk start).
        local_event_ms = group_events_ms - chunk_start_ms
        local_corrected_ms = recenter_stim_times(
            chunk_traces,
            local_event_ms,
            fs_Hz,
            max_offset_ms=max_stim_offset_ms,
            peak_mode=peak_mode,
            n_reference_channels=n_reference_channels,
            prewindow_ms=prewindow_ms,
            multi_peak=multi_peak,
            multi_peak_select=multi_peak_select,
            multi_peak_threshold=multi_peak_threshold,
            multi_peak_min_separation_ms=multi_peak_min_separation_ms,
        )
        # Remove artifacts on the chunk (in-place — we don't need the
        # pre-cleaning trace after this).
        chunk_cleaned, _ = remove_stim_artifacts(
            chunk_traces,
            local_corrected_ms,
            fs_Hz,
            method=artifact_method,
            artifact_window_ms=artifact_window_ms,
            saturation_threshold=saturation_threshold,
            baseline_threshold=baseline_threshold,
            poly_order=poly_order,
            artifact_window_only=artifact_window_only,
            copy=False,
            recording=recording,
            raw_traces=chunk_raw,
        )
        del chunk_raw
        # Sort offline on the cleaned chunk — reset is default True, so
        # each chunk sort is independent.
        sorting = rt_sort_obj.sort_offline(
            recording=chunk_cleaned,
            recording_window_ms=None,
            return_spikeinterface_sorter=True,
            verbose=False,
        )
        # ``sort_offline`` on an ndarray returns a NumpySorting with
        # no associated recording, so we must pass n_samples explicitly.
        chunk_n_samples = chunk_cleaned.shape[1]
        chunk_sd = _sorting_to_spikedata(sorting, fs_Hz, n_samples=chunk_n_samples)
        # Drop the big chunk arrays as soon as the SpikeData exists to
        # keep peak RSS at roughly one chunk's working set.
        del chunk_cleaned, chunk_traces, sorting
        # Extract peri-event SpikeData per event in this chunk.
        for i_in_group, kept_idx in enumerate(group_kept):
            t_local_corr = float(local_corrected_ms[i_in_group])
            t_global_corr = chunk_start_ms + t_local_corr
            per_event_global_corr_ms[kept_idx] = t_global_corr
            slice_lo_ms = t_local_corr - pre_ms
            slice_hi_ms = t_local_corr + post_ms
            # ``subtime(start, end, shift_to=peak)`` returns a SpikeData
            # with spike times relative to ``peak`` — so spike times
            # land in [-pre_ms, +post_ms] with start_time = -pre_ms.
            # Matches the output of ``SpikeData.align_to_events``.
            ev_sd = chunk_sd.subtime(slice_lo_ms, slice_hi_ms, shift_to=t_local_corr)
            per_event_sd[kept_idx] = ev_sd
        del chunk_sd
    # Guard: all kept events must have produced a slice.
    missing = [i for i, sd in enumerate(per_event_sd) if sd is None]
    if missing:
        raise RuntimeError(
            f"Internal error: {len(missing)} kept events did not receive a "
            f"peri-event slice (indices into kept set: {missing[:5]}...)"
        )
    # Build the final SpikeSliceStack. ``times_start_to_end`` is in
    # absolute recording ms (using the chunked-recentered event times).
    times_start_to_end = [
        (float(t - pre_ms), float(t + post_ms)) for t in per_event_global_corr_ms
    ]
    # Resolve neuron_attributes from the RTSort if available (so the
    # stack exposes per-unit metadata without duplicating it across
    # slices — ``SpikeSliceStack`` strips per-slice attrs by default).
    neuron_attributes = None
    if per_event_sd and per_event_sd[0].neuron_attributes is not None:
        neuron_attributes = per_event_sd[0].neuron_attributes
    stim_slices = SpikeSliceStack(
        spike_stack=list(per_event_sd),
        times_start_to_end=times_start_to_end,
        neuron_attributes=neuron_attributes,
    )
    if verbose:
        n_slices = len(stim_slices.spike_stack)
        print(
            f"Produced SpikeSliceStack with {n_slices} slices "
            f"(peri-event windows [-{pre_ms}, +{post_ms}] ms)"
        )
    return stim_slices
def _group_stim_events_into_chunks(times_ms, chunk_pre_ms, chunk_post_ms):
"""Group event indices whose peri-event chunk windows overlap.
Given sorted-or-unsorted event times, returns a list of lists of
indices (into the input array) such that:
* events within a group have overlapping chunk windows
``[t - chunk_pre_ms, t + chunk_post_ms]`` — so they are
naturally processed together in one chunk;
* no two groups have overlapping chunk windows.
Group boundaries are decided by sorted times; within each group
indices are returned in time order (matching the sort-order of
``times_ms``). Input indices, not times, are returned — the
caller keeps its own mapping.
"""
times_ms = np.asarray(times_ms, dtype=np.float64)
if len(times_ms) == 0:
return []
order = np.argsort(times_ms, kind="stable")
gap_threshold = chunk_pre_ms + chunk_post_ms
groups = [[int(order[0])]]
for idx in order[1:]:
prev_idx = groups[-1][-1]
if times_ms[idx] - times_ms[prev_idx] > gap_threshold:
groups.append([int(idx)])
else:
groups[-1].append(int(idx))
return groups
# ---------------------------------------------------------------------------
# Full-recording path — kept for ndarray inputs (caller has already
# materialised the traces).
# ---------------------------------------------------------------------------
def _sort_stim_full_recording(
    traces,
    recording_obj,
    fs_Hz,
    rt_sort_obj,
    stim_times_ms,
    pre_ms,
    post_ms,
    *,
    artifact_method,
    artifact_window_ms,
    saturation_threshold,
    baseline_threshold,
    poly_order,
    artifact_window_only,
    max_stim_offset_ms,
    peak_mode,
    n_reference_channels,
    prewindow_ms,
    multi_peak,
    multi_peak_select,
    multi_peak_threshold,
    multi_peak_min_separation_ms,
    recording_window_ms,
    verbose,
):
    """Single-pass backend for pre-materialised trace arrays.

    Recenters the stim times on the full trace array, strips artifacts,
    runs ``sort_offline`` once over the whole cleaned recording, and
    aligns the resulting SpikeData to the corrected event times.  Peak
    RAM is the caller's responsibility — the chunked path is preferred
    for lazy recordings.
    """
    from .artifact_removal import remove_stim_artifacts
    from .recentering import recenter_stim_times

    # Validate the ndarray input before any heavy work.
    if traces.ndim != 2:
        raise ValueError(
            f"Expected 2-D array (channels, samples), got shape {traces.shape}."
        )
    if fs_Hz is None:
        raise ValueError("fs_Hz is required when stim_recording is a numpy array.")
    fs_Hz = float(fs_Hz)

    if verbose:
        print("Recentering stim times (full-recording path)...")
    corrected_stim_ms = recenter_stim_times(
        traces,
        stim_times_ms,
        fs_Hz,
        max_offset_ms=max_stim_offset_ms,
        peak_mode=peak_mode,
        n_reference_channels=n_reference_channels,
        prewindow_ms=prewindow_ms,
        multi_peak=multi_peak,
        multi_peak_select=multi_peak_select,
        multi_peak_threshold=multi_peak_threshold,
        multi_peak_min_separation_ms=multi_peak_min_separation_ms,
    )
    if verbose:
        shifts = corrected_stim_ms - stim_times_ms
        print(
            f" Stim time corrections: "
            f"mean={np.mean(shifts):.2f} ms "
            f"max={np.max(np.abs(shifts)):.2f} ms"
        )
        print(f"Removing artifacts (method={artifact_method!r})...")

    cleaned, blanked_mask = remove_stim_artifacts(
        traces,
        corrected_stim_ms,
        fs_Hz,
        method=artifact_method,
        artifact_window_ms=artifact_window_ms,
        saturation_threshold=saturation_threshold,
        baseline_threshold=baseline_threshold,
        poly_order=poly_order,
        artifact_window_only=artifact_window_only,
        copy=True,
        recording=recording_obj,
        raw_traces=None,
    )
    if verbose:
        print(f" {100.0 * np.mean(blanked_mask):.1f}% of samples blanked")
        print("Running RT-Sort offline sorting on cleaned traces...")

    sorting = rt_sort_obj.sort_offline(
        recording=cleaned,
        recording_window_ms=recording_window_ms,
        return_spikeinterface_sorter=True,
        verbose=verbose,
    )
    # sort_offline on an ndarray produces a NumpySorting without an
    # associated recording; pass n_samples explicitly.
    sd = _sorting_to_spikedata(sorting, fs_Hz, n_samples=cleaned.shape[1])
    if verbose:
        print(f" {sd.N} units, {sum(len(t) for t in sd.train)} total spikes")
        print(
            f"Aligning to {len(corrected_stim_ms)} stim events "
            f"(window: -{pre_ms} to +{post_ms} ms)..."
        )
    stim_slices = sd.align_to_events(corrected_stim_ms, pre_ms, post_ms, kind="spike")
    if verbose:
        print(f" Produced SpikeSliceStack with {len(stim_slices.spike_stack)} slices")
    return stim_slices
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _find_prefilter_parent(recording):
"""Walk a SpikeInterface preprocessing chain upward until we leave
any bandpass/highpass/lowpass filter and return the first non-filter
parent. For the standard SpikeLab chain
``BandpassFilterRecording → ScaleRecording → …`` this returns the
``ScaleRecording`` — float32 uV traces without filter ringing.
Returns ``None`` if the top-level recording is already non-filter
(nothing to walk) or if the chain cannot be traversed.
"""
if recording is None:
return None
cls_name = type(recording).__name__
if "Filter" not in cls_name:
return None
cur = recording
visited: set = set()
while cur is not None and id(cur) not in visited:
visited.add(id(cur))
if "Filter" not in type(cur).__name__:
return cur
kwargs = getattr(cur, "_kwargs", None)
if not isinstance(kwargs, dict):
return None
parent = kwargs.get("recording") or kwargs.get("parent_recording")
if parent is None:
return None
cur = parent
return None
def _load_rt_sort(rt_sort, model, model_path, verbose):
"""Load or return an RTSort object."""
if isinstance(rt_sort, (str, Path)):
if verbose:
print(f"Loading RTSort from {rt_sort}...")
from ..rt_sort_runner import load_rt_sort as _load
return _load(Path(rt_sort), model=model, model_path=model_path)
return rt_sort
def _sorting_to_spikedata(sorting, fs_Hz, n_samples=None):
    """Build a minimal SpikeData from a SpikeInterface sorting.

    Spike trains come back from the sorter in sample indices; they are
    converted to milliseconds and packed into a SpikeData with no
    waveforms and no curation — the RTSort sequences are assumed to have
    been curated during the Phase 1 vanilla sort.

    Parameters:
        sorting: SpikeInterface ``NumpySorting``.
        fs_Hz (float): Sampling frequency in Hz.
        n_samples (int or None): Duration of the sort in samples.  A
            sorting produced by ``RTSort.sort_offline`` on a bare ndarray
            (as in the chunked path) has no associated recording, so
            ``sorting.get_num_samples()`` raises — pass the chunk's
            sample count explicitly in that case.  ``None`` falls back
            to ``sorting.get_num_samples()`` (which requires an
            associated recording).
    """
    from ...spikedata.spikedata import SpikeData

    # samples → ms, one train per unit, preserving unit order.
    train = [
        sorting.get_unit_spike_train(uid).astype(np.float64) / fs_Hz * 1000.0
        for uid in sorting.get_unit_ids()
    ]
    if n_samples is None:
        n_samples = sorting.get_num_samples()
    return SpikeData(train, length=n_samples / fs_Hz * 1000.0)