# Source code for spikify.encoders.temporal.contrast.threshold_based_algorithm

"""
.. raw:: html

    <h2>Threshold Based Representation Algorithm</h2>
"""

import numpy as np


[docs] def threshold_based_representation( signal: np.ndarray, factor: float | int | list[float | int] | np.ndarray ) -> tuple[np.ndarray, np.ndarray]: """ Perform Threshold-Based Representation (TBR) encoding on the input signal. This function takes a continuous signal and converts it into a spike train using a fixed threshold based on the signal's variations. A spike is generated when the variation exceeds the computed threshold. Refer to the :ref:`threshold_based_representation_algorithm_desc` for a detailed explanation of the TBR encoding algorithm. **Code Example:** .. code-block:: python import numpy as np from spikify.encoders.temporal.contrast import threshold_based_representation signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6]) factor = 0.5 encoded_signal, threshold = threshold_based_representation(signal, factor) .. doctest:: :hide: >>> import numpy as np >>> from spikify.encoders.temporal.contrast import threshold_based_representation >>> signal = np.array([0.1, 0.3, 0.4, 0.2, 0.5, 0.6]) >>> factor = 0.5 >>> encoded_signal, _ = threshold_based_representation(signal, factor) >>> encoded_signal.flatten() array([ 1, 0, -1, 1, 0, 0], dtype=int8) :param signal: Input signal to encode (1D or 2D: time × features or channels). :type signal: numpy.ndarray :param factor: The factor value (`factor`) that controls the noise-reduction threshold. Can be a float, an integer, or a list of floats or integers. :type factor: float | int | list[float | int] | numpy.ndarray :return: - spikes: A numpy array representing the encoded spike train. (values in {-1, 0, +1}) - thresholds: Per-feature or channel thresholds used for encoding, returned for use in decoding, shape (features or channels,). :rtype: tuple[numpy.ndarray, numpy.ndarray] :raises ValueError: If the input signal is empty or if the factor length does not match the number of features. 
""" # Input validation if len(signal) == 0: raise ValueError("Signal cannot be empty.") # Ensure 2D processing (T, F) if signal.ndim == 1: signal = signal.reshape(-1, 1) T, F = signal.shape # Handle factor if np.isscalar(factor): factors = np.full(F, float(factor)) else: factors = np.asarray(factor, dtype=float) if factors.ndim != 1: raise ValueError("Factor must be a scalar or a 1D sequence of numbers.") if factors.size != F: raise ValueError("Factor must match the number of features in the signal.") spike = np.zeros_like(signal, dtype=np.int8) # Compute variation exactly as in the original code # diff[t] = s[t+1] - s[t] for t = 0 to T-2 # diff[T-1] = diff[T-2] (last value set to second-last) diff = np.diff(signal, axis=0, append=signal[[0], :]) # append first value of signal to maintain shape diff[-1, :] = diff[-2, :] # force last to equal second-last # Compute threshold per feature (over all T variations, including the duplicated last) threshold = np.mean(diff, axis=0) + factors * np.std(diff, axis=0) # Generate spikes: compare on the full diff array (length S) threshold = threshold.reshape(1, threshold.shape[0]) spike[diff > threshold] = 1 spike[diff < -threshold] = -1 # Reshape threshold to 1D for return threshold = threshold.flatten() return spike, threshold