Source code for spikify.encoding.temporal.deconvolution.bens_spiker_algorithm
""".. raw:: html <h2>Bens Spiker Algorithm</h2>"""importnumpyasnpfromscipy.signal.windowsimportget_windowfrom.utilsimportWindowType
[docs]defbens_spiker(signal:np.ndarray,window_length:int|list[int],threshold:float|list[float],window_type:WindowType="boxcar",)->np.ndarray:""" Perform spike detection using Bens Spiker Algorithm. This function detects spikes in an input signal based on the comparison of cumulative errors calculated over a segment of the signal, which is filtered using a boxcar window. A spike is detected if the cumulative error between the filtered signal and the raw signal is below a certain threshold. Refer to the :ref:`bens_spiker_algorithm_desc` for a detailed explanation of the Ben's Spiker algorithm. **Code Example:** .. code-block:: python import numpy as np from spikify.encoding.temporal.deconvolution import bens_spiker signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1]) window_length = 3 threshold = 0.5 spikes = bens_spiker(signal, window_length, threshold) .. doctest:: :hide: >>> import numpy as np >>> from spikify.encoding.temporal.deconvolution import bens_spiker >>> signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1]) >>> window_length = 3 >>> threshold = 0.5 >>> spikes = bens_spiker(signal, window_length, threshold) >>> spikes array([0, 0, 1, 0, 0, 0, 0], dtype=int8) :param signal: The input signal to be analyzed. This should be a numpy ndarray. :type signal: numpy.ndarray :param window_length: The length of the window type filter window. Can be a int or a list of ints. :type window_length: int | list[int] :param threshold: Threshold value used to detect spikes. Can be a float or a list of floats. :type threshold: float | list[float] :return: A numpy array representing the detected spikes. :rtype: numpy.ndarray :raises ValueError: If the input signal is empty or if the window length is greater than the signal length. :raises TypeError: If the signal is not a numpy ndarray. """# Check for invalid inputsiflen(signal)==0:raiseValueError("Signal cannot be empty.")ifsignal.ndim==1:signal=signal.reshape(-1,1)S,F=signal.shapeifisinstance(window_length,int):window_lengths=[window_length]*Felifisinstance(window_length,list):ifnotall(isinstance(w,int)forwinwindow_length):raiseTypeError("All elements in window_length list must be integers.")window_lengths=window_lengthelse:raiseTypeError("Window lengths must be an int or a list of ints.")iflen(window_lengths)!=F:raiseValueError("Window lengths must match the number of features in the signal.")ifnp.any(np.array(window_lengths)>S):raiseValueError("All filter window sizes must be less than the length of the signal.")ifisinstance(threshold,float):thresholds=[threshold]*Felifisinstance(threshold,list):ifnotall(isinstance(w,float)forwinthreshold):raiseTypeError("All elements in threshold list must be float.")thresholds=thresholdelse:raiseTypeError("Threshold must be a float or a list of floats.")iflen(thresholds)!=F:raiseValueError("Threshold must match the number of features in the signal.")# Initialize the spike arrayspikes=np.zeros_like(signal,dtype=np.int8)# Create the boxcar filter windowfilter_window=[get_window(window_type,w)forwinwindow_lengths]# Copy of the signal to avoid modifying the original inputsignal_copy=np.copy(np.array(signal,dtype=np.float64))# Iterate over the signal to detect spikesforfeatinrange(F):forseq_idxinrange(len(signal[:,feat])-window_lengths[feat]+1):# Calculate errors using the filter windowsegment=signal_copy[seq_idx:seq_idx+window_lengths[feat],feat]error1=np.sum(np.abs(segment-filter_window[feat]),axis=0)error2=np.sum(np.abs(segment),axis=0)# Update signal and spike array if a spike is detectediferror1<=(error2-thresholds[feat]):signal_copy[seq_idx:seq_idx+window_lengths[feat],feat]-=filter_window[feat]spikes[seq_idx,feat]=1ifF==1:spikes=spikes.flatten()returnspikes