"""
.. raw:: html

    <h2>Moving Window Algorithm</h2>
"""
import numpy as np
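# [docs]  (apparently a leftover Sphinx "viewcode" link marker from HTML extraction; not part of the program)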
def moving_window(
signal: np.ndarray, window_length: int, threshold: float | int | list[float | int] | np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
"""
Perform Moving Window (MW) encoding on the input signal.
This function takes a continuous signal and converts it into a spike train using a moving window and
threshold-based approach. A spike is generated when the signal exceeds the calculated `base` plus or minus a
specified `threshold`.
Refer to the :ref:`moving_window_algorithm_desc` for a detailed explanation of the Moving Window encoding
algorithm.
**Code Example:**
.. code-block:: python
import numpy as np
from spikify.encoders.temporal.contrast import moving_window
signal = np.array([0.1, 0.3, 0.2, 0.5, 0.8, 1.0])
window_length = 3
threshold = 0.2
encoded_signal, thresholds = moving_window(signal, window_length, threshold)
.. doctest::
:hide:
>>> import numpy as np
>>> from spikify.encoders.temporal.contrast import moving_window
>>> signal = np.array([0.1, 0.3, 0.2, 0.5, 0.8, 1.0])
>>> window_length = 3
>>> threshold = 0.2
>>> encoded_signal, _ = moving_window(signal, window_length, threshold)
>>> encoded_signal.flatten()
array([0, 0, 0, 1, 1, 1], dtype=int8)
:param signal: Input signal to encode (1D or 2D: time × features or channels).
:type signal: numpy.ndarray
:param window_length: The size of the sliding window for calculating the signal base mean.
:type window_length: int
:param threshold: Threshold(s) for spike generation; scalar or 1D sequence matching features.
:type threshold: float | int | list[float | int] | numpy.ndarray
:return:
- spikes: A numpy array representing the encoded spike train. (values in {-1, 0, +1})
- thresholds: Per-feature or channel thresholds used for encoding, returned for use in decoding,
shape (features or channels,).
:rtype: tuple[numpy.ndarray, numpy.ndarray]
:raises ValueError: If the input signal is empty or if the threshold dimensions do not match the signal f
eature dimensions.
"""
# Check for empty signal
if len(signal) == 0:
raise ValueError("Signal cannot be empty.")
# Ensure 2D processing (T, F)
if signal.ndim == 1:
signal = signal.reshape(-1, 1)
T, F = signal.shape
# Handle threshold
if np.isscalar(threshold):
thresholds = np.full(F, float(threshold))
else:
thresholds = np.asarray(threshold, dtype=float)
if thresholds.ndim != 1:
raise ValueError("Threshold must be a scalar or a 1D sequence of numbers.")
if thresholds.size != F:
raise ValueError("Threshold must match the number of features in the signal.")
spikes = np.zeros_like(signal, dtype=np.int8)
# First loop: t = 0 : window_length
# For the first window_length samples, use the mean of available samples as base signal otherwise
# the first window_length samples will not be encoded since there are not enough samples to fill the window
for f in range(F):
base = np.mean(signal[:window_length, f])
for t in range(window_length):
if signal[t, f] > base + thresholds[f]:
spikes[t, f] = 1
elif signal[t, f] < base - thresholds[f]:
spikes[t, f] = -1
# Second loop: t = window_length : T
# For the rest of the signal, use the moving window to calculate the base signal
for f in range(F):
for t in range(window_length, T):
base = np.mean(signal[t - window_length : t, f])
if signal[t, f] > base + thresholds[f]:
spikes[t, f] = 1
elif signal[t, f] < base - thresholds[f]:
spikes[t, f] = -1
return spikes, thresholds