Source code for spikify.encoding.temporal.latency.burst_encoding_algorithm

"""
.. raw:: html

    <h2>Burst Encoding Algorithm</h2>
"""

import numpy as np


def burst_encoding(signal: np.ndarray, n_max: int, t_min: int, t_max: int, length: int) -> np.ndarray:
    """
    Perform burst encoding on the input signal.

    The signal is split into consecutive windows of ``length`` samples. Each window is
    reduced to its mean (negative samples are clipped to zero first), the window means are
    normalised per feature by their maximum, and every window is replaced by a burst of
    spikes: the normalised value determines both the number of spikes (up to ``n_max``)
    and the inter-spike interval (ISI, between ``t_min`` and ``t_max``). Larger values
    produce more spikes that are closer together.

    Refer to the :ref:`burst_encoding_algorithm_desc` for a detailed explanation of the
    Burst Encoding Algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoding.temporal.latency import burst_encoding

        np.random.seed(42)
        signal = np.random.rand(16)
        n_max = 4
        t_min = 2
        t_max = 6
        length = 16
        encoded_signal = burst_encoding(signal, n_max, t_min, t_max, length)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoding.temporal.latency import burst_encoding
        >>> np.random.seed(42)
        >>> signal = np.random.rand(16)
        >>> n_max = 4
        >>> t_min = 2
        >>> t_max = 6
        >>> length = 16
        >>> encoded_signal = burst_encoding(signal, n_max, t_min, t_max, length)
        >>> encoded_signal
        array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=int8)

    :param signal: The input signal to be encoded, either 1D ``(samples,)`` or 2D
        ``(samples, features)``. Array-likes are converted with ``numpy.asarray``.
    :type signal: numpy.ndarray
    :param n_max: The maximum number of spikes in each burst.
    :type n_max: int
    :param t_min: The minimum inter-spike interval (ISI).
    :type t_min: int
    :param t_max: The maximum inter-spike interval (ISI).
    :type t_max: int
    :param length: The number of samples per encoding window. The number of signal
        samples must be a multiple of this value.
    :type length: int
    :return: The burst-encoded spike train (dtype ``int8``) with the same number of
        samples as the input: 1D for a 1D input, ``(samples, features)`` otherwise.
    :rtype: numpy.ndarray
    :raises ValueError: If the input signal is empty, is not 1D or 2D, if ``length`` is
        not positive, or if the signal length is not a multiple of ``length``.
    :raises ValueError: If the longest required burst does not fit in ``length`` samples.

    """
    signal = np.asarray(signal)
    if signal.ndim not in (1, 2):
        raise ValueError(f"Signal must be 1D or 2D, got {signal.ndim} dimension(s).")

    # Work on a (samples, features) view; remember to flatten again on the way out.
    is_1d = signal.ndim == 1
    if is_1d:
        signal = signal[:, None]

    n_samples, n_features = signal.shape
    if n_samples == 0:
        raise ValueError("Signal cannot be empty.")
    # Guard explicitly: a zero length would otherwise surface as ZeroDivisionError below.
    if length <= 0:
        raise ValueError(f"The burst_encoding length must be a positive integer, got {length}.")
    if n_samples % length != 0:
        raise ValueError(
            f"The signal length ({n_samples}) is not a multiple of the burst_encoding length ({length})."
        )

    # Negative samples carry no intensity; average each window of `length` samples.
    window_means = np.clip(signal, 0, None).reshape(-1, length, n_features).mean(axis=1)

    # Normalise per feature; an all-zero feature keeps a divisor of 1 to avoid 0/0.
    peak = window_means.max(axis=0)
    peak[peak == 0] = 1
    normalized = window_means / peak

    # Stronger windows get more spikes and a shorter inter-spike interval.
    spike_num = np.ceil(normalized * n_max).astype(int)
    isi = np.ceil(t_max - normalized * (t_max - t_min)).astype(int)
    step = isi + 1  # distance between consecutive spike positions

    required_length = np.max(spike_num * step)
    if length < required_length:
        raise ValueError(f"Invalid stream length, the min length is {required_length}")

    n_windows = normalized.shape[0]
    spikes = np.zeros((n_windows, length, n_features), dtype=np.int8)
    for w in range(n_windows):
        for f in range(n_features):
            # Spikes start at the beginning of the window and repeat every `step` samples.
            spike_times = np.arange(0, spike_num[w, f] * step[w, f], step[w, f])
            spikes[w, spike_times[spike_times < length], f] = 1

    # Concatenate the windows back along the time axis.
    spikes = spikes.reshape(-1, n_features)
    return spikes.ravel() if is_1d else spikes