# Source code for spikelab.spike_sorting.stim_sorting.recentering

"""Stimulation time recentering.

Finds the actual stimulation artifact onset near each logged stim time
by detecting a chosen alignment point in the raw voltage traces:

* ``"abs_max"`` (default): sample with the largest ``|voltage|`` across
  channels — appropriate for monophasic pulses where there is a single
  artifact peak.
* ``"pos_peak"`` / ``"neg_peak"``: sample with the largest positive or
  most negative voltage in a top-K summed reference trace.
* ``"down_edge"``: up→down transition of a biphasic anodic-first pulse.
  First finds the negative peak in the search window, then the positive
  peak within a preceding ``prewindow_ms``, then returns the first
  positive-to-negative zero-crossing between them (falling back to the
  steepest negative slope if the signal does not cross zero).  This is
  the moment at which the stim current reverses direction — the AP
  trigger point for biphasic anodic-first protocols.
* ``"up_edge"``: symmetric version for biphasic cathodic-first pulses.

For the signed modes (``pos_peak``, ``neg_peak``, ``down_edge``,
``up_edge``) the reference trace is the *sum* of the top-K highest-
amplitude channels rather than the per-sample max.  Summing preserves
phase information (biphasic transitions add coherently across nearby
channels that see the same artifact; uncorrelated noise cancels) and
yields cleaner derivatives for edge detection.
"""

import warnings

import numpy as np


def _build_reference_trace(traces, n_reference_channels):
    """Collapse the strongest channels into one signed reference trace.

    Channels are ranked by their peak ``|voltage|`` and the K
    highest-ranked rows are added together sample by sample.

    Parameters:
        traces (np.ndarray): ``(channels, samples)``.
        n_reference_channels (int): K, the number of channels to sum.
            Values outside ``[1, n_channels]`` are clamped into it.

    Returns:
        reference (np.ndarray): Signed ``(samples,)`` array.
    """
    peak_per_channel = np.abs(traces).max(axis=1)

    # Clamp K: never more than the channels available, never fewer than one.
    n_keep = min(int(n_reference_channels), traces.shape[0])
    if n_keep < 1:
        n_keep = 1

    # argpartition moves the n_keep largest amplitudes to the tail without
    # paying for a full sort.
    strongest = np.argpartition(peak_per_channel, -n_keep)[-n_keep:]
    return traces[strongest].sum(axis=0)


def _find_down_edge(reference, lo, hi, prewindow_ms, fs_Hz, neg_peak=None):
    """Locate the positive-to-negative transition of a biphasic pulse.

    Steps:
      1. Take the trough: the caller-supplied ``neg_peak`` if given,
         otherwise the minimum of ``reference[lo:hi]``.
      2. Look back at most ``prewindow_ms`` (at least one sample, never
         before ``lo``) and take the maximum there as the crest.
      3. Return the index just before the first drop in sign within
         ``reference[crest:trough + 1]``.
      4. If the sign never drops (e.g. DC offset), return the start of
         the steepest falling step in that same span.
      5. If there is no room to look back (trough sits at ``lo``),
         return the trough itself.

    Parameters:
        reference (np.ndarray): 1-D reference trace.
        lo (int): Inclusive start of the search window.
        hi (int): Exclusive end of the search window.
        prewindow_ms (float): Look-back length before the trough, in ms.
        fs_Hz (float): Sampling frequency in Hz.
        neg_peak (int or None): Optional pre-selected trough sample.

    Returns:
        sample (int): Index into ``reference`` of the transition.
    """
    trough = neg_peak
    if trough is None:
        trough = lo + int(np.argmin(reference[lo:hi]))

    n_back = int(round(prewindow_ms * fs_Hz / 1000.0))
    if n_back < 1:
        n_back = 1
    search_start = max(lo, trough - n_back)
    if trough <= search_start:
        # Nothing precedes the trough inside the window.
        return trough
    crest = search_start + int(np.argmax(reference[search_start:trough]))

    span = reference[crest : trough + 1]

    # A negative step in sign() marks +→0, 0→- or +→- between neighbours.
    falling = np.flatnonzero(np.diff(np.sign(span)) < 0)
    if falling.size:
        return crest + int(falling[0])

    # No sign change: use the sharpest downward step instead.
    slopes = np.diff(span)
    if slopes.size == 0:
        return trough
    return crest + int(np.argmin(slopes))


def _find_up_edge(reference, lo, hi, prewindow_ms, fs_Hz, pos_peak=None):
    """Locate the negative-to-positive transition of a biphasic pulse.

    Mirror image of ``_find_down_edge`` for cathodic-first pulses: take
    the crest (caller-supplied ``pos_peak`` or the maximum of
    ``reference[lo:hi]``), find the trough in the look-back window before
    it, and return the index just before the first rise in sign between
    them.  Falls back to the steepest rising step when the sign never
    rises, and to the crest itself when there is no room to look back.

    Parameters:
        reference (np.ndarray): 1-D reference trace.
        lo (int): Inclusive start of the search window.
        hi (int): Exclusive end of the search window.
        prewindow_ms (float): Look-back length before the crest, in ms.
        fs_Hz (float): Sampling frequency in Hz.
        pos_peak (int or None): Optional pre-selected crest sample.

    Returns:
        sample (int): Index into ``reference`` of the transition.
    """
    crest = pos_peak
    if crest is None:
        crest = lo + int(np.argmax(reference[lo:hi]))

    n_back = int(round(prewindow_ms * fs_Hz / 1000.0))
    if n_back < 1:
        n_back = 1
    search_start = max(lo, crest - n_back)
    if crest <= search_start:
        # Nothing precedes the crest inside the window.
        return crest
    trough = search_start + int(np.argmin(reference[search_start:crest]))

    span = reference[trough : crest + 1]

    # A positive step in sign() marks -→0, 0→+ or -→+ between neighbours.
    rising = np.flatnonzero(np.diff(np.sign(span)) > 0)
    if rising.size:
        return trough + int(rising[0])

    # No sign change: use the sharpest upward step instead.
    slopes = np.diff(span)
    if slopes.size == 0:
        return crest
    return trough + int(np.argmax(slopes))


# Alignment targets accepted by ``recenter_stim_times(peak_mode=...)``;
# any other value is rejected with a ``ValueError``.
_VALID_PEAK_MODES = ("abs_max", "pos_peak", "neg_peak", "down_edge", "up_edge")
# Accepted ``multi_peak_select`` values: which qualifying peak in the search
# window to lock onto when ``multi_peak=True``.
_VALID_MULTI_PEAK_SELECT = ("first", "last")


def _multi_peak_anchor(
    reference,
    lo,
    hi,
    peak_mode,
    multi_peak_select,
    multi_peak_threshold,
    multi_peak_min_separation_ms,
    fs_Hz,
):
    """Choose one anchor sample among the peaks of ``reference[lo:hi]``.

    Intended for stimulation trains, where the search window can span
    several pulses and a plain ``argmax``/``argmin`` would lock onto
    whichever pulse is largest.  Every local peak whose height reaches
    ``multi_peak_threshold`` times the tallest value in the window is a
    candidate (candidates are kept at least
    ``multi_peak_min_separation_ms`` apart), and the earliest or latest
    candidate is returned.

    The signal searched for peaks depends on ``peak_mode``:
      * ``abs_max``: ``|reference|``.
      * ``pos_peak``, ``up_edge``: the positive lobe (negatives zeroed).
      * anything else (``neg_peak``, ``down_edge``): the negative lobe,
        flipped so troughs become peaks.

    For the ``*_edge`` modes the result is the main artifact peak of the
    chosen pulse, not the zero-crossing; the caller passes it on to
    :func:`_find_down_edge` / :func:`_find_up_edge`.

    Parameters:
        reference (np.ndarray): 1-D reference trace.
        lo (int): Inclusive start of the search window.
        hi (int): Exclusive end of the search window.
        peak_mode (str): Alignment mode (see above).
        multi_peak_select (str): ``"first"`` picks the earliest
            candidate; any other value picks the latest.
        multi_peak_threshold (float): Fraction of the tallest value a
            peak must reach to qualify.
        multi_peak_min_separation_ms (float): Minimum spacing between
            candidate peaks, in ms.
        fs_Hz (float): Sampling frequency in Hz.

    Returns:
        anchor_sample (int): Index into the full ``reference`` array
            (already offset by ``lo``).  ``lo`` for an empty window; the
            plain ``argmax``/``argmin`` of the window when the searched
            signal has no positive value or no peak qualifies.
    """
    from scipy.signal import find_peaks  # deferred until actually needed

    window = reference[lo:hi]
    if window.size == 0:
        return lo

    # Build a non-negative envelope whose maxima are the peaks of interest.
    if peak_mode == "abs_max":
        envelope = np.abs(window)
    elif peak_mode in ("pos_peak", "up_edge"):
        envelope = np.maximum(window, 0.0)
    else:  # neg_peak, down_edge
        envelope = -np.minimum(window, 0.0)

    # Single-extremum fallback used whenever peak picking cannot decide.
    if peak_mode in ("neg_peak", "down_edge"):
        pick_extreme = np.argmin
    else:
        pick_extreme = np.argmax

    min_gap = max(1, int(round(multi_peak_min_separation_ms * fs_Hz / 1000.0)))

    tallest = envelope.max()
    if tallest <= 0:
        # Window is flat zero or has only the opposite polarity.
        return lo + int(pick_extreme(window))

    candidates, _ = find_peaks(
        envelope,
        height=multi_peak_threshold * float(tallest),
        distance=min_gap,
    )
    if candidates.size == 0:
        # No interior local maximum qualifies (e.g. a monotonic ramp).
        return lo + int(pick_extreme(window))

    chosen = candidates[0] if multi_peak_select == "first" else candidates[-1]
    return lo + int(chosen)


def recenter_stim_times(
    traces,
    stim_times_ms,
    fs_Hz,
    max_offset_ms=50.0,
    *,
    peak_mode="abs_max",
    n_reference_channels=8,
    prewindow_ms=5.0,
    warn_offset_ms=3.0,
    multi_peak=False,
    multi_peak_select="first",
    multi_peak_threshold=0.6,
    multi_peak_min_separation_ms=2.0,
):
    """Find actual stimulation artifact times near logged stim times.

    For each logged stim time, searches a window of ``±max_offset_ms`` in
    the raw voltage traces and returns the sample at the alignment point
    selected by ``peak_mode``.  This corrects for timing offsets between
    the stimulation hardware trigger log and the artifact in the
    recording.

    Parameters:
        traces (np.ndarray): Raw voltage traces, shape
            ``(channels, samples)``.
        stim_times_ms (array-like): Logged stimulation event times in
            milliseconds.  Need not be sorted.
        fs_Hz (float): Sampling frequency in Hz.
        max_offset_ms (float): Radius of the search window around each
            logged stim time, in milliseconds.  Default 50.0.
        peak_mode (str): Alignment target.  One of:

            * ``"abs_max"`` (default): largest ``|voltage|`` across
              channels.  Backward-compatible with the pre-``peak_mode``
              API.
            * ``"pos_peak"``: largest positive voltage in the top-K
              summed reference trace.
            * ``"neg_peak"``: most negative voltage in the top-K summed
              reference.
            * ``"down_edge"``: up→down transition for biphasic
              anodic-first pulses (see module docstring).
            * ``"up_edge"``: down→up transition for biphasic
              cathodic-first pulses.
        n_reference_channels (int): Number of highest-amplitude channels
            summed to build the signed reference trace for
            non-``abs_max`` modes.  Default ``8``.  Ignored for
            ``abs_max``.
        prewindow_ms (float): For ``down_edge`` / ``up_edge``, length of
            the window *before* the main peak in which to search for the
            preceding opposite-polarity peak.  Default ``5.0``.
        warn_offset_ms (float or None): When the median
            ``|corrected - logged|`` shift exceeds this threshold, emit a
            ``UserWarning``.  A large systematic shift usually means a
            fixed hardware-vs-log delay, a wrong time column in the stim
            log, or a unit mismatch (ms vs s vs samples) rather than
            genuine jitter.  Set to ``None`` to silence.  Default ``3.0``
            ms — well above one-sample jitter at 20–30 kHz.
        multi_peak (bool): Opt-in support for multi-pulse stim trains.
            When ``True``, the search window is treated as potentially
            containing multiple stimulation pulses (e.g. a 100 Hz train),
            and the alignment target is the **first** or **last**
            qualifying pulse rather than the strongest one.  Default
            ``False`` — preserves backward-compatible single-peak
            behavior.
        multi_peak_select (str): When ``multi_peak=True``, which
            qualifying peak to lock onto.  ``"first"`` (default) = first
            pulse onset (matches "first-pulse alignment" used for train
            PSTHs).  ``"last"`` = last pulse onset (useful for studying
            after-train rebound).  Ignored when ``multi_peak=False``.
        multi_peak_threshold (float): When ``multi_peak=True``, only
            peaks whose amplitude is at least this fraction of the
            largest peak in the search window are considered "real
            pulses".  Default ``0.6`` — accepts pulses up to 40% weaker
            than the strongest while still rejecting noise.
        multi_peak_min_separation_ms (float): When ``multi_peak=True``,
            the minimum spacing between candidate peaks.  Prevents
            multi-sample peaks of a single pulse from being counted as
            separate pulses.  Default ``2.0`` ms — well below any
            sensible inter-pulse interval (5 ms = 200 Hz; 10 ms =
            100 Hz).

    Returns:
        corrected_ms (np.ndarray): Corrected stim times in milliseconds,
            same length as ``stim_times_ms``.  Search windows that extend
            outside the recording are clipped to the recording boundary;
            an event whose window lies entirely outside the recording is
            pinned to the nearest boundary sample.

    Raises:
        ValueError: If ``peak_mode`` is unknown, or (with
            ``multi_peak=True``) ``multi_peak_select`` is unknown or
            ``multi_peak_threshold`` is outside ``(0, 1]``.

    Notes:
        * When multiple stim events have overlapping search windows, each
          is recentered independently.
        * For monophasic pulses the ``*_edge`` modes degrade gracefully:
          the pre-window search returns the opposite polarity's noise
          peak and the zero-crossing fallback lands near the onset of the
          single artifact — but ``pos_peak`` / ``neg_peak`` will give
          cleaner results in that case.
        * For single-pulse stim, ``multi_peak=True`` degrades to the
          original single-peak behavior (only one peak in the window is
          above threshold; first==last).  Set it always-on if you mix
          single-pulse and train conditions in one recording.
    """
    if peak_mode not in _VALID_PEAK_MODES:
        raise ValueError(
            f"Unknown peak_mode {peak_mode!r}; "
            f"expected one of {_VALID_PEAK_MODES}"
        )
    if multi_peak and multi_peak_select not in _VALID_MULTI_PEAK_SELECT:
        raise ValueError(
            f"Unknown multi_peak_select {multi_peak_select!r}; "
            f"expected one of {_VALID_MULTI_PEAK_SELECT}"
        )
    if multi_peak and not (0.0 < multi_peak_threshold <= 1.0):
        raise ValueError(
            f"multi_peak_threshold must be in (0, 1]; got {multi_peak_threshold}"
        )

    stim_times_ms = np.asarray(stim_times_ms, dtype=np.float64)
    n_samples = traces.shape[1]
    offset_samples = int(np.round(max_offset_ms * fs_Hz / 1000.0))

    # Reference trace: unsigned max-of-abs for abs_max (backward compat),
    # signed top-K sum for all other modes.
    if peak_mode == "abs_max":
        reference = np.max(np.abs(traces), axis=0)
    else:
        reference = _build_reference_trace(traces, n_reference_channels)

    corrected = np.empty_like(stim_times_ms)
    for i, t_ms in enumerate(stim_times_ms):
        center = int(np.round(t_ms * fs_Hz / 1000.0))

        # Clamp BOTH bounds into [0, n_samples].  A negative ``hi`` would be
        # read by slicing as a from-the-end index (silently searching almost
        # the whole recording), and ``lo`` past the end yields an empty
        # slice on which argmax/argmin raise.
        lo = min(max(center - offset_samples, 0), n_samples)
        hi = min(max(center + offset_samples + 1, 0), n_samples)

        if hi <= lo:
            # Window lies entirely outside the recording: nothing to search,
            # so pin the event to the nearest recording boundary sample.
            peak_sample = min(max(center, 0), max(n_samples - 1, 0))
        elif multi_peak:
            anchor = _multi_peak_anchor(
                reference,
                lo,
                hi,
                peak_mode,
                multi_peak_select,
                multi_peak_threshold,
                multi_peak_min_separation_ms,
                fs_Hz,
            )
            if peak_mode == "down_edge":
                # anchor IS the negative-peak sample of the chosen pulse;
                # the pre-window search runs ahead of it for the pos peak.
                peak_sample = _find_down_edge(
                    reference, lo, hi, prewindow_ms, fs_Hz, neg_peak=anchor
                )
            elif peak_mode == "up_edge":
                peak_sample = _find_up_edge(
                    reference, lo, hi, prewindow_ms, fs_Hz, pos_peak=anchor
                )
            else:  # abs_max, pos_peak, neg_peak: the anchor is the answer
                peak_sample = anchor
        elif peak_mode in ("abs_max", "pos_peak"):
            peak_sample = lo + int(np.argmax(reference[lo:hi]))
        elif peak_mode == "neg_peak":
            peak_sample = lo + int(np.argmin(reference[lo:hi]))
        elif peak_mode == "down_edge":
            peak_sample = _find_down_edge(reference, lo, hi, prewindow_ms, fs_Hz)
        else:  # up_edge
            peak_sample = _find_up_edge(reference, lo, hi, prewindow_ms, fs_Hz)

        corrected[i] = peak_sample / fs_Hz * 1000.0

    if warn_offset_ms is not None and len(stim_times_ms) > 0:
        median_abs_offset_ms = float(np.median(np.abs(corrected - stim_times_ms)))
        if median_abs_offset_ms > warn_offset_ms:
            warnings.warn(
                f"recenter_stim_times: median |offset| = "
                f"{median_abs_offset_ms:.2f} ms over {len(stim_times_ms)} events "
                f"exceeds warn_offset_ms ({warn_offset_ms} ms). This is well "
                f"above one-sample jitter and usually indicates a fixed "
                f"hardware-vs-log delay, a wrong time column in the stim log, "
                f"or a unit mismatch (ms vs s vs samples). Verify against a "
                f"test pulse, then either accept the shift or pass "
                f"warn_offset_ms=None to silence.",
                UserWarning,
                stacklevel=2,
            )

    return corrected