Source code for spikify.encoding.temporal.contrast.step_forward_algorithm

"""
.. raw:: html

    <h2>Step Forward Algorithm</h2>
"""

import numpy as np


def step_forward(signal: np.ndarray, threshold: float | list[float]) -> np.ndarray:
    """
    Perform Step-Forward encoding on the input signal.

    This function takes a continuous signal and converts it into a spike train using a dynamically
    updated baseline and threshold-based approach. A spike is generated when the signal exceeds or
    drops below the dynamically adjusted baseline (`Base`) by the specified `Threshold`.

    Refer to the :ref:`step_forward_algorithm_desc` for a detailed explanation of the Step-Forward
    encoding algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoding.temporal.contrast import step_forward

        signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        threshold = 0.2
        encoded_signal = step_forward(signal, threshold)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoding.temporal.contrast import step_forward
        >>> signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        >>> threshold = 0.2
        >>> encoded_signal = step_forward(signal, threshold)
        >>> encoded_signal
        array([0, 0, 1, 0, 0, 1], dtype=int8)

    :param signal: The input signal to be encoded. This should be a numpy ndarray.
    :type signal: numpy.ndarray
    :param threshold: The threshold value(s) for spike detection. Can be a float or a list of floats.
    :type threshold: float | list[float]
    :return: A numpy array representing the encoded spike train.
    :rtype: numpy.ndarray
    :raises ValueError: If the input signal is empty.
    :raises TypeError: If the signal is not a numpy ndarray.

    """
    if len(signal) == 0:
        raise ValueError("Signal cannot be empty.")

    if signal.ndim == 1:
        signal = signal.reshape(-1, 1)

    S, F = signal.shape

    if isinstance(threshold, float):
        thresholds = [threshold] * F
    elif isinstance(threshold, list):
        if not all(isinstance(w, float) for w in threshold):
            raise TypeError("All elements in threshold list must be float.")
        thresholds = threshold
    else:
        raise TypeError("Threshold must be a float or a list of floats.")

    if len(thresholds) != F:
        raise ValueError("Threshold must match the number of features in the signal.")

    spike = np.zeros_like(signal, dtype=np.int8)

    # Base value initialized at the start of the signal
    for feat in range(F):
        base = signal[0, feat]
        for value_idx, value in enumerate(signal[:, feat]):
            if value > base + thresholds[feat]:
                spike[value_idx, feat] = 1
                base += thresholds[feat]
            elif value < base - thresholds[feat]:
                spike[value_idx, feat] = -1
                base -= thresholds[feat]

    if F == 1:
        spike = spike.flatten()

    return spike