Source code for spikify.encoders.temporal.contrast.threshold_based_algorithm
"""
.. raw:: html

    <h2>Threshold Based Representation Algorithm</h2>
"""

import numpy as np
def threshold_based_representation(
    signal: np.ndarray, factor: float | int | list[float | int] | np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Perform Threshold-Based Representation (TBR) encoding on the input signal.

    This function converts a continuous signal into a spike train using a fixed threshold derived
    from the signal's variations (the differences between consecutive samples). A positive spike (1)
    is emitted where the variation exceeds the threshold, and a negative spike (-1) where it falls
    below the negated threshold.

    Refer to the :ref:`threshold_based_representation_algorithm_desc` for a detailed explanation
    of the TBR encoding algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoders.temporal.contrast import threshold_based_representation

        signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        factor = 0.5
        encoded_signal, threshold = threshold_based_representation(signal, factor)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoders.temporal.contrast import threshold_based_representation
        >>> signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6])
        >>> factor = 0.5
        >>> encoded_signal, threshold = threshold_based_representation(signal, factor)
        >>> encoded_signal
        array([ 1,  0, -1,  1,  0,  0], dtype=int8)

    :param signal: The input signal to be encoded, as a NumPy array of shape ``(T,)`` or ``(T, F)``
        (T time steps, F features).
    :type signal: numpy.ndarray
    :param factor: Multiplier applied to the standard deviation of the variations when computing
        the noise-reduction threshold. Either a scalar (float or int) applied to every feature, or
        a 1D sequence (list or NumPy array) with one value per feature.
    :type factor: float | int | list[float | int] | numpy.ndarray
    :return: A tuple containing the encoded spike train (``int8`` values in {-1, 0, 1}) and the
        threshold computed for each feature.
    :rtype: tuple[numpy.ndarray, numpy.ndarray]
    :raises ValueError: If the input signal is empty or if the length of ``factor`` does not match
        the number of features.
    :raises TypeError: If ``factor`` is neither a scalar nor a 1D sequence.
    """
    # Input validation
    if len(signal) == 0:
        raise ValueError("Signal cannot be empty.")

    # Ensure 2D processing (T, F)
    if signal.ndim == 1:
        signal = signal.reshape(-1, 1)
    T, F = signal.shape

    # Handle factor: expand a scalar to every feature, or validate a per-feature sequence
    if np.isscalar(factor):
        factors = np.full(F, float(factor))
    else:
        factors = np.asarray(factor, dtype=float)
        if factors.ndim != 1:
            raise TypeError("Factor must be a scalar or a 1D sequence of numbers.")
        if factors.size != F:
            raise ValueError("Factor must match the number of features in the signal.")

    spike = np.zeros_like(signal, dtype=np.int8)

    # Compute the variation between consecutive samples:
    # diff[t] = s[t+1] - s[t] for t = 0 to T-2
    # diff[T-1] = diff[T-2] (last value set to second-last)
    # Appending the first sample is only a placeholder so that diff has T rows;
    # that last row is overwritten on the next line.
    diff = np.diff(signal, axis=0, append=signal[[0], :])
    diff[-1, :] = diff[-2, :]  # force last to equal second-last

    # Compute threshold per feature (over all T variations, including the duplicated last)
    threshold = np.mean(diff, axis=0) + factors * np.std(diff, axis=0)

    # Generate spikes: compare on the full diff array (length T);
    # reshaping to (1, F) broadcasts each feature's threshold over every time step
    threshold = threshold.reshape(1, threshold.shape[0])
    spike[diff > threshold] = 1
    spike[diff < -threshold] = -1

    # Flatten single-feature output (1D input or a (T, 1) array)
    if F == 1:
        spike = spike.flatten()

    return spike, threshold.flatten()
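The docstring example can be traced step by step. The sketch below is an illustrative, dependency-free re-derivation of the same computation using only the standard library; it is not part of spikify. It assumes `statistics.pstdev` as a stand-in for `np.std`, since both compute the population standard deviation (NumPy's default `ddof=0`).

```python
from statistics import fmean, pstdev

# Docstring example: one feature, six samples.
signal = [0.1, 0.3, 0.4, 0.2, 0.5, 0.6]
factor = 0.5

# Step 1: first differences s[t+1] - s[t] for t = 0 .. T-2.
diff = [b - a for a, b in zip(signal, signal[1:])]

# Step 2: pad to length T by repeating the last difference. The function reaches
# the same values by appending signal[0] and then overwriting the final entry
# with the second-to-last one.
diff.append(diff[-1])

# Step 3: threshold = mean + factor * population standard deviation.
threshold = fmean(diff) + factor * pstdev(diff)

# Step 4: +1 above the threshold, -1 below its negation, 0 otherwise.
spikes = [1 if d > threshold else -1 if d < -threshold else 0 for d in diff]

print([round(d, 2) for d in diff])  # [0.2, 0.1, -0.2, 0.3, 0.1, 0.1]
print(round(threshold, 4))          # 0.1764
print(spikes)                       # [1, 0, -1, 1, 0, 0]
```

The resulting spike list matches the `int8` array expected by the doctest: only the variations 0.2, -0.2 and 0.3 clear the threshold of roughly 0.176 in magnitude.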
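For multi-feature input, each column gets its own factor and its own threshold. The following NumPy sketch repeats the function's core steps on an illustrative two-feature array (the values are made up for this example, not taken from spikify) to show how the `(1, F)` threshold broadcasts against the `(T, F)` variation array.

```python
import numpy as np

# Illustrative two-feature signal with shape (T, F) = (5, 2).
signal = np.array([
    [0.0, 1.0],
    [0.5, 1.0],
    [0.4, 3.0],
    [1.0, 2.0],
    [0.2, 2.5],
])
factors = np.array([0.5, 1.0])  # one factor per feature (column)

# Same variation rule as the function: first differences, last row repeated.
diff = np.diff(signal, axis=0, append=signal[[0], :])
diff[-1, :] = diff[-2, :]

# Column-wise statistics give one threshold per feature, shape (F,).
threshold = diff.mean(axis=0) + factors * diff.std(axis=0)

# Reshaping to (1, F) lets each feature's threshold broadcast across all T rows.
row_threshold = threshold.reshape(1, -1)
spikes = np.zeros_like(signal, dtype=np.int8)
spikes[diff > row_threshold] = 1
spikes[diff < -row_threshold] = -1

print([round(t, 4) for t in threshold.tolist()])  # [0.1823, 1.3695]
print(spikes.tolist())  # [[1, 0], [0, 1], [1, 0], [-1, 0], [-1, 0]]
```

Because the threshold already includes the mean variation, the positive and negative cut-offs are symmetric around zero rather than around the mean: in the first column the mean variation is -0.12, yet spikes still fire at +0.1823 and -0.1823.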