Source code for spikify.encoders.temporal.deconvolution.modified_hough_spiker_algorithm

"""
.. raw:: html

    <h2>Modified Hough Spiker Algorithm</h2>
"""

import numpy as np
from scipy.signal import firwin
from .utils import WindowType


def modified_hough_spiker(
    signal: np.ndarray,
    window_length: int,
    cutoff: float | np.ndarray,
    threshold: float | int | list[float | int] | np.ndarray,
    width: int | None = None,
    window_type: WindowType = "hann",
    pass_zero: bool | str = True,
    scale: bool = True,
    fs: float | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Perform Modified Hough Spiker Algorithm (MHSA) encoding on the input signal.

    The Modified Hough Spiker Algorithm (MHSA) is a relaxed variant of the original Hough Spiker
    Algorithm (HSA). HSA is strict: it emits a spike only if the signal segment under the filter
    is greater than or equal to the filter coefficients at every sample of the window. If the
    signal dips below the filter at even one sample, no spike is emitted.

    MHSA tolerates a small shortfall. For each window it sums the amounts by which the filter
    exceeds the signal (samples where the signal is at or above the filter contribute zero). If
    this total error is less than or equal to the threshold, a spike is emitted anyway. As in
    HSA, the filter is then subtracted from the signal segment, removing the detected pattern
    before the search continues at the next time step.

    .. note::

        - MHSA requires non-negative inputs. Each feature is therefore preprocessed
          independently: features containing negative values are shifted so that their minimum
          becomes 0, and features whose maximum exceeds 1 are divided by that maximum, so every
          feature lies in [0, 1].
        - The FIR filter is designed with `scipy.signal.firwin` using the given cutoff, window
          type, transition width, ``pass_zero``, ``scale``, and ``fs``.
        - For multi-feature signals, the same filter is applied to every feature; only the
          shift and normalization are computed per feature.

    Refer to the :ref:`modified_hough_spiker_algorithm_desc` for a detailed explanation of the
    Modified Hough Spiker Algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoders.temporal.deconvolution import modified_hough_spiker

        signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1])
        window_length = 3
        threshold = 0.5
        cutoff = 0.1
        encoded_signal, fir_coeffs, shift, norm = modified_hough_spiker(signal, window_length, cutoff, threshold)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoders.temporal.deconvolution import modified_hough_spiker
        >>> signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1])
        >>> window_length = 3
        >>> threshold = 0.5
        >>> cutoff = 0.1
        >>> encoded_signal, _, _, _ = modified_hough_spiker(signal, window_length, cutoff, threshold)
        >>> encoded_signal.flatten()
        array([0, 0, 1, 1, 0, 0, 1], dtype=int8)

    :param signal: Input signal to encode, either 1D (time) or 2D (time × features or channels).
    :type signal: numpy.ndarray
    :param window_length: Length of the FIR filter (number of coefficients). Must not exceed the
        number of time steps in ``signal``.
    :type window_length: int
    :param cutoff: Cutoff frequency or frequencies for the FIR filter design, given as a scalar or
        as an array of band edges. If ``fs`` is None, frequencies are normalized so that 1
        corresponds to the Nyquist frequency; otherwise they are in the same units as ``fs``.
    :type cutoff: float | numpy.ndarray
    :param threshold: Maximum allowed shortfall (error) for a spike to be emitted. Either a scalar
        applied to all features or a 1D sequence with one value per feature.
    :type threshold: float | int | list[float | int] | numpy.ndarray
    :param width: Approximate transition width for the FIR filter design (optional). If given,
        ``firwin`` uses a Kaiser window design and ignores ``window_type``.
    :type width: int | None
    :param window_type: Window function for FIR filter design (e.g., 'hann', 'hamming', 'blackman', 'boxcar').
    :type window_type: str
    :param pass_zero: If True (or 'lowpass'), the filter has unit gain at DC (low-pass); if False
        (or 'highpass'), the DC gain is 0 (high-pass). Other band types such as 'bandpass' and
        'bandstop' are passed through to ``firwin``.
    :type pass_zero: bool | str
    :param scale: If True, scale the coefficients so that the frequency response is exactly 1 at a
        reference frequency: DC for a low-pass filter, Nyquist for a single-band high-pass filter,
        and the center of the first passband otherwise.
    :type scale: bool
    :param fs: Sampling frequency (optional). If given, ``cutoff`` and ``width`` are expressed in
        the same units; otherwise they are normalized to the Nyquist frequency.
    :type fs: float | None
    :return:
        - spikes: Encoded spike train with values in {0, +1}, dtype ``int8``, shape
          (time, features or channels). A 1D input yields shape (time, 1).
        - fir_bank: Filter coefficients used, shape (window_length, features or channels).
        - shift: Per-feature shift subtracted to make the signal non-negative, shape
          (features or channels,).
        - norm: Per-feature divisor used to scale the signal to [0, 1], shape
          (features or channels,).
    :rtype: tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]
    :raises ValueError: If the input signal is empty, if ``threshold`` is neither a scalar nor a
        1D sequence with one entry per feature, or if ``window_length`` is greater than the
        number of time steps in the signal.
    """
    # Check for empty signal
    if len(signal) == 0:
        raise ValueError("Signal cannot be empty.")

    # Ensure 2D processing with shape (T, F)
    if signal.ndim == 1:
        signal = signal.reshape(-1, 1)
    T, F = signal.shape

    # Validate window_length
    if window_length > T:
        raise ValueError("window_length must not exceed the number of time steps in the signal.")

    # Broadcast a scalar threshold, or validate a per-feature sequence
    if np.isscalar(threshold):
        thresholds = np.full(F, float(threshold))
    else:
        thresholds = np.asarray(threshold, dtype=float)
        if thresholds.ndim != 1:
            raise ValueError("Threshold must be a scalar or a 1D sequence of numbers.")
        if thresholds.size != F:
            raise ValueError("Threshold must match the number of features in the signal.")

    spikes = np.zeros_like(signal, dtype=np.int8)

    # Design the FIR filter once; it is shared by all features
    fir = firwin(window_length, cutoff, width=width, window=window_type, pass_zero=pass_zero, scale=scale, fs=fs)

    # Stack one copy of the filter per feature, so that coefficients could later be adapted
    # to the amplitude of individual features
    fir_bank = np.stack([fir] * F, axis=0).T

    # Work on a float copy so the caller's array is never modified
    signal_copy = np.array(signal, dtype=np.float64)

    # Shift features that contain negative values so that their minimum becomes 0
    shift = signal_copy.min(axis=0)
    shift[shift > 0] = 0  # only shift if negative values are present
    signal_copy -= shift

    # Divide features whose maximum exceeds 1 by that maximum
    norm = signal_copy.max(axis=0)
    norm[norm <= 1] = 1  # only normalize if max is greater than 1
    signal_copy /= norm

    for f in range(F):
        for t in range(T):
            # Determine the end index for the current window (truncated at the signal end)
            end_index = min(t + window_length, T)

            # Extract the signal segment and the matching part of the filter
            signal_segment = signal_copy[t:end_index, f]
            filter_segment = fir_bank[: end_index - t, f]

            # Error = total amount by which the filter exceeds the signal in this window
            error = np.sum(np.maximum(filter_segment - signal_segment, 0))

            # If the error is within the threshold, emit a spike and subtract the filter
            if error <= thresholds[f]:
                signal_copy[t:end_index, f] -= filter_segment
                spikes[t, f] = 1

    return spikes, fir_bank, shift, norm
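To make the difference between the strict HSA rule and the relaxed MHSA rule concrete, the following minimal NumPy-only sketch evaluates a single window both ways. The helper names `hsa_accepts`, `mhsa_shortfall`, and `mhsa_accepts` are hypothetical and not part of spikify. The three-tap filter `[0, 1, 0]` is what `firwin(3, 0.1, window="hann")` should return for the docstring example; it is hardcoded here so the sketch runs without SciPy.

```python
import numpy as np


def hsa_accepts(segment: np.ndarray, fir: np.ndarray) -> bool:
    """Hypothetical helper: strict HSA rule, the signal must cover the filter at every sample."""
    return bool(np.all(segment >= fir))


def mhsa_shortfall(segment: np.ndarray, fir: np.ndarray) -> float:
    """Hypothetical helper: total amount by which the filter exceeds the signal (MHSA error)."""
    return float(np.sum(np.maximum(fir - segment, 0.0)))


def mhsa_accepts(segment: np.ndarray, fir: np.ndarray, threshold: float) -> bool:
    """Hypothetical helper: MHSA rule, accept if the shortfall is within the threshold."""
    return mhsa_shortfall(segment, fir) <= threshold


# Assumed filter: firwin(3, 0.1) with a Hann window, hardcoded to avoid a SciPy dependency.
fir = np.array([0.0, 1.0, 0.0])
# A window whose centre sample falls 0.5 short of the filter.
window = np.array([0.0, 0.5, 0.3])

print(hsa_accepts(window, fir))        # False: 0.5 < 1.0 at the centre tap
print(mhsa_shortfall(window, fir))     # 0.5
print(mhsa_accepts(window, fir, 0.5))  # True: shortfall 0.5 <= threshold 0.5
```

Samples where the signal is above the filter (the outer taps here) contribute nothing to the error, so only the centre tap counts. HSA rejects this window outright, while MHSA with a threshold of 0.5 accepts it.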
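The per-feature preprocessing described in the note can be reproduced in isolation. This sketch mirrors the shift and normalization lines of `modified_hough_spiker` and adds an inverse, `undo_shift_and_normalize`. Both function names are hypothetical, and the inverse assumes you want to map normalized values back to the original units using the returned `shift` and `norm`.

```python
import numpy as np


def shift_and_normalize(signal: np.ndarray):
    """Hypothetical helper mirroring the encoder's per-feature shift/normalization step."""
    x = np.array(signal, dtype=np.float64)
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    shift = x.min(axis=0)
    shift[shift > 0] = 0  # only shift features that contain negative values
    x -= shift
    norm = x.max(axis=0)
    norm[norm <= 1] = 1  # only rescale features whose maximum exceeds 1
    x /= norm
    return x, shift, norm


def undo_shift_and_normalize(x: np.ndarray, shift: np.ndarray, norm: np.ndarray) -> np.ndarray:
    """Hypothetical inverse: map normalized values back to the original units."""
    return x * norm + shift


# Two features: the first has negatives and a range of 4, the second is already in [0, 1].
raw = np.array([[-2.0, 0.2],
                [0.0, 0.4],
                [2.0, 0.9]])
x, shift, norm = shift_and_normalize(raw)
print(shift)    # shift == [-2., 0.]
print(norm)     # norm == [4., 1.]
print(x[:, 0])  # first feature == [0., 0.5, 1.]
```

The second feature passes through unchanged, because its minimum is positive and its maximum does not exceed 1, so neither the shift nor the division applies to it.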
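The returned `fir_bank`, `shift`, and `norm` carry enough information for an approximate reconstruction. Because every accepted spike at time `t` subtracts one copy of the filter starting at `t`, a natural decoder superimposes one filter copy per spike, which amounts to a convolution truncated to the signal length, and then undoes the normalization and shift. The function `reconstruct` below is a hypothetical sketch under that assumption, not a spikify API. The spike train is the one from the docstring example, and the filter `[0, 1, 0]` is assumed to be what `firwin(3, 0.1, window="hann")` returns.

```python
import numpy as np


def reconstruct(spikes, fir_bank, shift, norm):
    """Hypothetical decoder (not a spikify API): superimpose one filter copy per
    spike, then undo the per-feature normalization and shift."""
    spikes = np.asarray(spikes, dtype=np.float64)
    if spikes.ndim == 1:
        spikes = spikes.reshape(-1, 1)
    T, F = spikes.shape
    recon = np.empty((T, F))
    for f in range(F):
        # Full convolution places fir_bank[:, f] starting at every spike time;
        # truncate to T because the encoder clips windows at the end of the signal.
        recon[:, f] = np.convolve(spikes[:, f], fir_bank[:, f])[:T]
    # Map back from the normalized [0, 1] range to the original units.
    return recon * norm + shift


# Values from the docstring example (window_length=3, cutoff=0.1, Hann window).
spikes = np.array([[0], [0], [1], [1], [0], [0], [1]], dtype=np.int8)
fir_bank = np.array([[0.0], [1.0], [0.0]])  # assumed firwin output, shape (3, 1)
shift = np.array([0.0])
norm = np.array([1.0])

recon = reconstruct(spikes, fir_bank, shift, norm)
print(recon[:, 0])  # [0, 0, 0, 1, 1, 0, 0] as floats
```

The reconstruction is coarse in this case because the three-tap Hann design reduces to a unit impulse delayed by one sample; longer filters spread each spike over several samples.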