Source code for spikify.encoders.temporal.contrast.step_forward_algorithm
"""
.. raw:: html

    <h2>Step Forward Algorithm</h2>
"""

import numpy as np
def step_forward(
    signal: np.ndarray, threshold: float | int | list[float | int] | np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Perform Step-Forward (SF) encoding on the input signal.

    This function converts a continuous signal into a spike train by tracking a
    baseline (``base``) that starts at the first sample of each feature and moves
    in steps of size ``threshold``. At each later timestep, a ``+1`` spike is
    emitted when the signal rises above ``base + threshold``, and ``base`` is then
    raised by ``threshold``. A ``-1`` spike is emitted when the signal falls below
    ``base - threshold``, and ``base`` is then lowered by ``threshold``. Otherwise
    no spike is emitted and ``base`` stays unchanged.

    Refer to the :ref:`step_forward_algorithm_desc` for a detailed explanation of
    the SF encoding algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoders.temporal.contrast import step_forward

        signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        threshold = 0.2
        encoded_signal, thresholds = step_forward(signal, threshold)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoders.temporal.contrast import step_forward
        >>> signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        >>> threshold = 0.2
        >>> encoded_signal, _ = step_forward(signal, threshold)
        >>> encoded_signal.flatten()
        array([0, 0, 1, 0, 0, 1], dtype=int8)

    :param signal: Input signal to encode, either 1D (time) or 2D (time × features or channels).
    :type signal: numpy.ndarray
    :param threshold: Threshold(s) for spike generation. Pass either a scalar, which is applied to every feature, or a 1D sequence with one value per feature.
    :type threshold: float | int | list[float | int] | numpy.ndarray
    :return:
        - spikes: Encoded spike train of shape ``(time, features)`` and dtype ``int8``, with values in {-1, 0, +1}. A 1D input is treated as a single feature, so its output has shape ``(time, 1)``.
        - thresholds: Per-feature (or per-channel) thresholds used for encoding, shape ``(features,)``, returned for use in decoding.
    :rtype: tuple[numpy.ndarray, numpy.ndarray]
    :raises ValueError: If the input signal is empty, or if the threshold is neither a scalar nor a 1D sequence whose length matches the number of signal features.
    """
    # Input validation
    if len(signal) == 0:
        raise ValueError("Signal cannot be empty.")

    # Ensure 2D processing with shape (T, F): T timesteps, F features
    if signal.ndim == 1:
        signal = signal.reshape(-1, 1)
    T, F = signal.shape

    # Handle threshold: broadcast a scalar to every feature, otherwise validate the sequence
    if np.isscalar(threshold):
        thresholds = np.full(F, float(threshold))
    else:
        thresholds = np.asarray(threshold, dtype=float)
        if thresholds.ndim != 1:
            raise ValueError("Threshold must be a scalar or a 1D sequence of numbers.")
        if thresholds.size != F:
            raise ValueError("Threshold must match the number of features in the signal.")

    spike = np.zeros_like(signal, dtype=np.int8)

    # Encode each feature independently
    for feat in range(F):
        # The baseline starts at the first sample, so timestep 0 never produces a spike
        base = signal[0, feat]
        for t in range(1, T):
            value = signal[t, feat]
            if value > base + thresholds[feat]:
                spike[t, feat] = 1
                base += thresholds[feat]
            elif value < base - thresholds[feat]:
                spike[t, feat] = -1
                base -= thresholds[feat]

    return spike, thresholds
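`step_forward` returns the per-feature thresholds "for use in decoding", but it does not return the first sample that seeds each baseline. The sketch below shows one way to rebuild the step-wise baseline that the encoder tracks, assuming the caller keeps that first sample (`initial`). `step_forward_decode` is a hypothetical helper written for illustration and is not part of spikify's API. Because the encoder moves the baseline by at most one threshold per timestep, the reconstruction can lag the original signal wherever the signal changes faster than that.

```python
import numpy as np


def step_forward_decode(
    spikes: np.ndarray, thresholds: np.ndarray, initial: np.ndarray
) -> np.ndarray:
    """Rebuild the SF baseline from a spike train of shape (T, F).

    Hypothetical helper (not spikify API): assumes ``initial`` holds the first
    sample of each feature, which the encoder does not return.
    """
    spikes = np.asarray(spikes, dtype=float)
    if spikes.ndim == 1:
        # Treat a 1D spike train as a single feature, mirroring the encoder
        spikes = spikes.reshape(-1, 1)
    thresholds = np.asarray(thresholds, dtype=float).reshape(1, -1)
    initial = np.asarray(initial, dtype=float).reshape(1, -1)
    # Each +1 / -1 spike moves the baseline up / down by one threshold step.
    # The encoder never spikes at t = 0, so row 0 reproduces ``initial``.
    return initial + thresholds * np.cumsum(spikes, axis=0)


# Spikes and threshold from the doctest example (the signal starts at 0.1)
spikes = np.array([0, 0, 1, 0, 0, 1], dtype=np.int8)
reconstruction = step_forward_decode(spikes, np.array([0.2]), np.array([0.1]))
print(reconstruction.ravel())
```

For the doctest spikes, the rebuilt baseline is approximately 0.1, 0.1, 0.3, 0.3, 0.3, 0.5. The sample at index 4 (0.5) never triggered a spike because the encoder uses a strict comparison against `base + threshold`.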