import warnings
import numpy as np
__all__ = ["RateData"]
from .pairwise import PairwiseCompMatrix
from concurrent.futures import ThreadPoolExecutor
from .utils import (
compute_cross_correlation_with_lag,
PCA_reduction,
UMAP_reduction,
UMAP_graph_communities,
_get_attr,
_resolve_n_jobs,
)
[docs]
class RateData:
"""A 2D instantaneous firing rate matrix with unit-to-unit correlation capabilities.
Parameters:
inst_Frate_data (array): 2D array of shape (U, T). Each value is the
instantaneous firing rate. U is the number of units/neurons and
T is the number of time bins.
times (list): List of time values that each column index in
inst_Frate_data represents. For example, times = [5, 10, 15]
so inst_Frate_data column 0 is 5 ms, column 1 is 10 ms, and
column 2 is 15 ms.
neuron_attributes (list or None): List of dicts, one per unit,
containing arbitrary metadata about each neuron. None if not
provided.
Attributes:
inst_Frate_data (array): 2D array of shape (U, T). Each value is the
instantaneous firing rate. U is the number of units/neurons and
T is the number of time bins.
times (list): List of time values that each column index in
inst_Frate_data represents. For example, times = [5, 10, 15]
so inst_Frate_data column 0 is 5 ms, column 1 is 10 ms, and
column 2 is 15 ms.
neuron_attributes (list or None): List of dicts, one per unit,
containing arbitrary metadata about each neuron. None if not
provided.
N (int): Number of units in inst_Frate_data.
Notes:
- ``times`` may contain negative values when the RateData represents an
event-aligned window (e.g., times from -200 to +500 ms around a stimulus).
- ``subtime`` always treats ``start``/``end`` as literal time values.
Use ``subtime_by_index`` for index-based slicing with negative indexing.
"""
[docs]
def __init__(self, inst_Frate_data, times, neuron_attributes=None, rate_unit=None):
"""Initialize a RateData object.
Parameters:
inst_Frate_data (numpy.ndarray): Firing rate data, shape (N, T).
times (numpy.ndarray or list): Time points, length T.
neuron_attributes (list or None): Per-unit attribute dicts.
rate_unit (str or None): Unit of the rate values. Typically
``"Hz"`` (spikes/s) for ``resampled_isi`` or ``"kHz"``
(spikes/ms) for ``sliding_rate``. When *None*, the unit
is unspecified.
"""
if inst_Frate_data.ndim != 2:
raise ValueError(
f"rates must be a 2D array, got shape {inst_Frate_data.shape}"
)
if len(times) != inst_Frate_data.shape[1]:
raise ValueError(
"Number of columns in inst_Frate_data must be the same as length of times"
)
if not isinstance(times, np.ndarray):
times = np.array(times)
# Validate times: must be all-finite and monotonically non-decreasing.
# Non-finite ``times`` causes silent filter-mask failures (NaN compares
# False) downstream in ``subtime``; unsorted ``times`` produces non-
# monotonic outputs that violate the documented contract.
if len(times) > 0:
# Check finite only on numeric dtypes — string/object arrays are
# accepted by the existing API and cannot be NaN-tested.
if np.issubdtype(times.dtype, np.number) and not np.all(np.isfinite(times)):
raise ValueError(
"times must be all-finite (no NaN or inf). Got at least "
"one non-finite entry."
)
if np.issubdtype(times.dtype, np.number) and not np.all(
np.diff(times) >= 0
):
raise ValueError(
"times must be monotonically non-decreasing. Sort the "
"array (and the matching columns of inst_Frate_data) "
"before constructing RateData."
)
self.inst_Frate_data = np.array(inst_Frate_data, dtype=float)
self.times = times
self.rate_unit = rate_unit
self.N = inst_Frate_data.shape[0]
self.neuron_attributes = None
if neuron_attributes is not None:
self.neuron_attributes = list(neuron_attributes)
if len(neuron_attributes) != self.N:
raise ValueError(
f"neuron_attributes has {len(neuron_attributes)} items "
f"but inst_Frate_data has {self.N} rows"
)
def __repr__(self) -> str:
t0 = float(self.times[0]) if len(self.times) > 0 else 0.0
t1 = float(self.times[-1]) if len(self.times) > 0 else 0.0
return f"RateData(shape={self.inst_Frate_data.shape}, time_range=[{t0:.1f}, {t1:.1f}])"
[docs]
def subset(self, units, by=None):
"""Extract a subset of units/neurons from the rate data.
Parameters:
units (list or array): Unit indices to extract. If by is None,
this should always be a list of ints. If by is not None,
the list can contain ints or strings.
by (str or None): Neuron attribute key to match against. Only
use this if you initialized the object with
neuron_attributes. Set to the key that contains neuron_id
values. None selects by index (default).
Returns:
result (RateData): New RateData object containing only the
specified units.
"""
if isinstance(units, int):
units = [units]
# For case where user inputs a single string for units when using by option
if isinstance(units, str):
units = [units]
units = set(units)
if by is not None:
# VALUE-BASED: Look up by neuron_attribute
if self.neuron_attributes is None:
raise ValueError("can't use `by` without `neuron_attributes`")
_missing = object()
units = {
i
for i in range(self.N)
if _get_attr(self.neuron_attributes[i], by, _missing) in units
}
units = sorted(units)
output = self.inst_Frate_data[units, :]
neuron_attributes = None
if self.neuron_attributes is not None:
neuron_attributes = [self.neuron_attributes[i] for i in units]
return RateData(
inst_Frate_data=output,
times=self.times,
neuron_attributes=neuron_attributes,
rate_unit=self.rate_unit,
)
[docs]
def subtime(self, start, end):
"""Extract a subset of time points from the rate data using time values.
Original time values are preserved in the output.
Parameters:
start (int or float): Starting time value (inclusive).
end (int or float): Ending time value (exclusive).
Returns:
result (RateData): New RateData object containing only the
specified time range.
Notes:
- Start and end are always treated as literal time values (not
offsets from the end). To slice by array index with negative
indexing support, use ``subtime_by_index(start_idx, end_idx)``.
"""
if len(self.times) == 0:
raise ValueError(
f"cannot apply subtime to RateData with empty times array "
f"(requested [{start}, {end}])"
)
# Handle start
if start is None or start is Ellipsis:
start = self.times[0] if len(self.times) > 0 else 0
# Handle end — use a value just past the last time point so the
# mask (times < end) includes the final bin.
if end is None or end is Ellipsis:
if len(self.times) > 1:
end = self.times[-1] + (self.times[1] - self.times[0])
elif len(self.times) == 1:
end = self.times[-1] + 1
else:
end = 0
# Validate
if start >= end:
raise ValueError(f"start ({start}) must be less than end ({end})")
mask = (self.times >= start) & (self.times < end)
# Check if start and end were in range
if not np.any(mask):
raise ValueError(
f"No time points found in range [{start}, {end}). "
f"The available range is [{self.times[0]}, {self.times[-1]}]"
)
output = self.inst_Frate_data[:, mask]
new_times = self.times[mask]
return RateData(
inst_Frate_data=output,
times=new_times,
neuron_attributes=self.neuron_attributes,
rate_unit=self.rate_unit,
)
[docs]
def subtime_by_index(self, start_idx, end_idx):
"""Extract a subset of time points from the rate data using time index values.
Original time values are preserved in the output.
Parameters:
start_idx (int): Starting time index (inclusive).
end_idx (int): Ending time index (exclusive).
Returns:
result (RateData): New RateData object containing only the
specified time range.
Notes:
- Supports negative indexing (e.g., -5 selects 5 from the end).
- To slice by time values instead of array indices, use
``subtime(start, end)``.
"""
if start_idx < 0:
start_idx += len(self.times)
if end_idx < 0:
end_idx += len(self.times)
if start_idx < 0 or start_idx >= len(self.times):
raise ValueError(f"start_idx {start_idx} out of range")
if end_idx <= start_idx or end_idx > len(self.times):
raise ValueError(f"end_idx {end_idx} invalid")
output = self.inst_Frate_data[:, start_idx:end_idx]
new_times = self.times[start_idx:end_idx]
return RateData(
inst_Frate_data=output,
times=new_times,
neuron_attributes=self.neuron_attributes,
rate_unit=self.rate_unit,
)
[docs]
def frames(self, length, overlap=0):
    """Split the rate data into a RateSliceStack of fixed-length windows.

    Parameters:
        length (float): Length of each window in milliseconds.
        overlap (float): Overlap between consecutive windows in
            milliseconds. Default 0. Must be in ``[0, length)``.

    Returns:
        stack (RateSliceStack): Stack of rate data windows, one per frame.

    Notes:
        - Windows that would extend past the end of the recording are
          excluded.
        - overlap must be non-negative and strictly less than length.
          Negative overlap (i.e. gaps between windows) is rejected
          because the parameter semantically means an overlap, not a
          stride.
    """
    from .rateslicestack import RateSliceStack

    if overlap < 0:
        raise ValueError(
            f"overlap must be non-negative, got {overlap}. The parameter "
            "represents an overlap, not a stride; use a smaller `length` "
            "and post-filter slices for gapped windows."
        )
    stride = length - overlap
    if stride <= 0:
        raise ValueError("overlap must be less than length")
    if len(self.times) < 2:
        raise ValueError(
            "Cannot frame a RateData with fewer than 2 time points; "
            f"got {len(self.times)}. At least 2 time points are "
            "required to infer the bin step_size."
        )
    first_t = float(self.times[0])
    last_t = float(self.times[-1])
    bin_width = float(self.times[1] - self.times[0])
    # frames() places windows on the uniform grid implied by
    # times[1] - times[0]. Validate the rest of the grid matches —
    # non-uniform times silently misalign frame boundaries with
    # bins, and the downstream np.stack in RateSliceStack fails
    # opaquely on the resulting shape-mismatched frames.
    gaps = np.diff(np.asarray(self.times, dtype=float))
    if not np.allclose(gaps, bin_width, rtol=1e-6, atol=1e-9):
        raise ValueError(
            "RateData.frames requires uniformly-spaced times; got "
            f"min step {gaps.min():g}, max step {gaps.max():g} "
            f"(first-pair step {bin_width:g}). Resample to a "
            "uniform grid before framing."
        )
    # The 1e-9 epsilon keeps the final window when float rounding would
    # otherwise drop it just below the limit.
    limit = last_t - length + bin_width + 1e-9
    windows = []
    for win_start in np.arange(first_t, limit, stride):
        windows.append((float(win_start), float(win_start) + length))
    if not windows:
        raise ValueError(
            f"Recording length ({last_t - first_t + bin_width:.1f} ms) is shorter "
            f"than frame length ({length} ms)"
        )
    return RateSliceStack(self, times_start_to_end=windows)
[docs]
def get_pairwise_fr_corr(
    self, compare_func=compute_cross_correlation_with_lag, max_lag=10, n_jobs=-1
):
    """Compute unit-to-unit similarity from the firing rate matrix (U, T).

    Parameters:
        compare_func (callable): Comparison function from utils. Specify
            cross-correlation or cosine similarity. The default is cross
            correlation. See utils.py for details.
        max_lag (int): Max number of lag steps around 0 to consider for
            finding the max correlation. If None, lag is set to 0.
        n_jobs (int): Number of threads for parallel computation. -1 uses
            all cores (default), 1 disables parallelism, None is serial.

    Returns:
        corr_matrix (PairwiseCompMatrix): Maximum correlation coefficients
            between all unit/neuron pairs. matrix[i, j] is the max
            correlation between unit i and unit j. Values range from -1
            to 1. Diagonal is always 1 (self-correlation).
        lag_matrix (PairwiseCompMatrix): Time lags (in time bins) at which
            maximum correlation occurs. lag_matrix[i, j] is the lag where
            correlation between i and j is maximal. Positive lag means
            unit j leads unit i (j fires earlier). Negative lag means
            unit i leads unit j (i fires earlier). Diagonal is always 0.
    """
    rates = self.inst_Frate_data
    n_units = rates.shape[0]
    corr = np.full((n_units, n_units), np.nan)
    lags = np.full((n_units, n_units), np.nan)
    # Only the upper triangle (including the diagonal) is computed; the
    # lower triangle is filled by symmetry with the lag sign flipped.
    upper_pairs = [(i, j) for i in range(n_units) for j in range(i, n_units)]

    def _score(pair):
        # Returns the pair alongside its (max_corr, lag) result so the
        # fill loop below knows where each value belongs.
        i, j = pair
        return pair, compare_func(rates[i, :], rates[j, :], max_lag=max_lag)

    workers = _resolve_n_jobs(n_jobs)
    if workers > 1 and len(upper_pairs) > 1:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            scored = list(pool.map(_score, upper_pairs))
    else:
        scored = [_score(p) for p in upper_pairs]
    for (i, j), (best_corr, best_lag) in scored:
        corr[i, j] = best_corr
        corr[j, i] = best_corr
        lags[i, j] = best_lag
        lags[j, i] = -best_lag
    # Output is UxU, wrapped in PairwiseCompMatrix for API consistency
    meta = {"compare_func": compare_func.__name__, "max_lag": max_lag}
    return (
        PairwiseCompMatrix(matrix=corr, metadata=meta),
        PairwiseCompMatrix(matrix=lags, metadata=meta),
    )
[docs]
def get_manifold(
self,
method: str = "PCA",
n_components: int = 2,
**kwargs,
):
"""Project the firing-rate data into a low-dimensional manifold using PCA or UMAP.
Parameters:
method (str): Which dimensionality reduction method to use.
Either ``"PCA"`` (default) or ``"UMAP"``.
n_components (int): Number of output dimensions to return
(default 2).
**kwargs: Additional options for UMAP. If method is ``"UMAP"``,
you can specify use_graph_communities (bool), return_labels
(bool), and other UMAP-specific keyword arguments such as
n_neighbors, min_dist, metric, or resolution.
Returns:
result (tuple): Depends on method and options:
If method is ``"PCA"``: ``(embedding, explained_variance_ratio,
components)`` where embedding has shape (T, n_components),
explained_variance_ratio has shape (n_components,), and
components has shape (n_components, U).
If method is ``"UMAP"``: ``(embedding, trustworthiness)`` where
embedding has shape (T, n_components) and trustworthiness is
a float from 0 to 1.
If method is ``"UMAP"`` with use_graph_communities=True and
return_labels=True: ``(embedding, labels, trustworthiness)``.
Notes:
- To visualise the resulting embedding, use
:func:`~spikelab.spikedata.plot_utils.plot_manifold`. It
accepts the embedding array directly and supports background
masks, continuous colour values, and discrete group colouring.
"""
if (
isinstance(n_components, (int, float, np.integer, np.floating))
and n_components <= 0
):
raise ValueError(f"n_components must be > 0, got {n_components}")
# Shape is (U, T); treat each time bin as a sample.
data_T = self.inst_Frate_data.T # (T, U)
method_upper = method.upper()
if method_upper == "PCA":
if kwargs:
warnings.warn(
f"Additional keyword arguments {list(kwargs.keys())} are ignored for method='{method}'.",
UserWarning,
)
return PCA_reduction(
data_T, n_components=n_components
) # (embedding, var_ratio, components)
if method_upper == "UMAP":
# Optional graph-based UMAP + Louvain communities.
use_graph_communities = kwargs.pop("use_graph_communities", False)
return_labels = kwargs.pop("return_labels", False)
if return_labels and not use_graph_communities:
warnings.warn(
"return_labels=True has no effect without use_graph_communities=True; "
"labels will not be returned.",
UserWarning,
stacklevel=2,
)
if use_graph_communities:
embedding, labels, tw = UMAP_graph_communities(
data_T,
n_components=n_components,
**kwargs,
)
if return_labels:
return embedding, labels, tw
return embedding, tw
# Default: plain UMAP embedding + trustworthiness.
return UMAP_reduction(
data_T,
n_components=n_components,
**kwargs,
) # (embedding, trustworthiness)
raise ValueError(
f"Unknown manifold method '{method}' (expected 'PCA' or 'UMAP')."
)