Source code for spikify.encoders.temporal.latency.burst_coding_algorithm

"""
.. raw:: html

    <h2>Burst Coding Algorithm</h2>
"""

import numpy as np


[docs] def burst_coding(signal: np.ndarray, n_max: int, t_min: int, t_max: int, interval_length: int) -> np.ndarray: """ Perform Burst Coding (BC) on the input signal. This function implements a biologically inspired burst coding scheme where each input intensity (normalized P ∈ [0, 1]) is converted into a burst of spikes: - Number of spikes N_s(P) = ceil(P * n_max) - ISI(P): Fixed at t_max when N_s = 1; otherwise linearly mapped between t_min (strong input) and t_max (weak input) for N_s > 1. The signal is divided into non-overlapping blocks of length ``interval_length``. The mean intensity per block determines the burst parameters (N_s and ISI). The burst is placed at regular intervals starting from the beginning of the output block. Higher intensities produce bursts with more spikes and smaller ISIs. The ``interval_length`` must be large enough to fit the longest burst. .. note:: - BC requires normalized input signals between 0 and 1. If the input signal contains negative values, they are shifted to be non-negative and then normalized. Refer to the :ref:`burst_coding_algorithm_desc` for a detailed explanation of the Burst Coding Algorithm. **Code Example:** .. code-block:: python import numpy as np from spikify.encoders.temporal.latency import burst_coding np.random.seed(42) signal = np.random.rand(16) n_max = 4 t_min = 2 t_max = 6 length = 16 encoded_signal = burst_coding(signal, n_max, t_min, t_max, length) .. doctest:: :hide: >>> import numpy as np >>> from spikify.encoders.temporal.latency import burst_coding >>> np.random.seed(42) >>> signal = np.random.rand(16) >>> n_max = 4 >>> t_min = 2 >>> t_max = 6 >>> length = 16 >>> encoded_signal = burst_coding(signal, n_max, t_min, t_max, length) >>> encoded_signal array([1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8) :param signal: Input signal to encode (1D or 2D: timestamps × features). :type signal: numpy.ndarray :param n_max: Maximum number of spikes in a burst (for P = 1). Must be ≥ 1. :type n_max: int :param t_min: Minimum inter-spike interval (ISI) in time steps (for strong inputs with N_s > 1). :type t_min: int :param t_max: Maximum inter-spike interval (ISI) in time steps (for weak inputs or N_s = 1). :type t_max: int :param interval_length: Length of each output block corresponding to one input block. Must divide the signal length evenly and be large enough to fit the longest burst (approximately n_max * (t_min + 1)). :type interval_length: int :return: A numpy array representing the encoded spike train. :rtype: numpy.ndarray :raises ValueError: If signal is empty, interval_length does not divide signal length, or interval_length is too small for the longest possible burst. """ # Check for empty signal if len(signal) == 0: raise ValueError("Signal cannot be empty.") # Ensure 2D processing (T, F) if signal.ndim == 1: signal = signal.reshape(-1, 1) T, F = signal.shape if T % interval_length != 0: raise ValueError(f"The interval_length ({interval_length}) is not a factor of the signal length ({T}).") signal_copy = signal.copy() # signal_copy = np.clip(signal_copy, 0, None) shift = signal_copy.min(axis=0) shift[shift > 0] = 0 # only shift if negative values are present signal_copy -= shift # Compute max amplitude per feature to be used for scaling if max amplitude is grater than 1 max_amp = signal_copy.max(axis=0) # Find features that require scaling features_to_scale = np.where(max_amp > 1)[0] for f in features_to_scale: signal_copy[:, f] /= max_amp[f] signal_copy = np.mean(signal_copy.reshape(-1, interval_length, F), axis=1) spike_num = np.ceil(signal_copy * n_max).astype(int) ISI = np.ceil(t_max - signal_copy * (t_max - t_min)).astype(int) required_length = np.max(spike_num * (ISI + 1)) if interval_length < required_length: raise ValueError(f"Invalid stream length, the min length is {required_length}") spikes = np.zeros((T // interval_length, interval_length, F), dtype=np.int8) for i in range(signal_copy.shape[0]): for f in range(F): spike_times = np.arange(0, spike_num[i, f] * (ISI[i, f] + 1), ISI[i, f] + 1) spikes[i, spike_times[spike_times < interval_length], f] = 1 spikes = spikes.reshape(-1, F) # Flatten if input was 1D if F == 1: spikes = spikes.flatten() return spikes