Source code for spikify.encoders.temporal.global_referenced.phase_encoding_algorithm

"""
.. raw:: html

    <h2>Phase Encoding Algorithm</h2>
"""

import numpy as np


def phase(signal: np.ndarray, num_bits: int) -> np.ndarray:
    """
    Perform Phase Encoding (PE) on the input signal.

    The signal is split into consecutive chunks of ``num_bits`` samples. The mean of
    each chunk is converted to a phase angle with ``arcsin``, the angle is quantized
    into one of ``2**num_bits`` levels, and the level is written out as ``num_bits``
    binary spikes (most significant bit first) in the positions of that chunk.

    .. note::

        - PE works on signals in the range [0, 1]. Per feature, negative values are
          removed by shifting the signal up by its minimum, and the signal is divided
          by its maximum when that maximum is greater than 1.
        - The chunk means are additionally divided by their per-feature maximum (when
          it is positive), so the largest chunk always maps to the highest level.
        - Integer and boolean inputs are converted to ``float64`` before processing.
          The input array is never modified.

    Refer to the :ref:`phase_encoding_algorithm_desc` for a detailed explanation of the
    Phase Encoding Algorithm.

    **Code Example:**

    .. code-block:: python

        import numpy as np
        from spikify.encoders.temporal.global_referenced import phase

        signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1, 0.2])
        num_bits = 4
        encoded_signal = phase(signal, num_bits)

    .. doctest::
        :hide:

        >>> import numpy as np
        >>> from spikify.encoders.temporal.global_referenced import phase
        >>> signal = np.array([0.1, 0.2, 0.3, 1.0, 0.5, 0.3, 0.1, 0.2])
        >>> num_bits = 4
        >>> encoded_signal = phase(signal, num_bits)
        >>> encoded_signal
        array([1, 1, 1, 1, 1, 0, 0, 0], dtype=uint8)

    :param signal: The input signal to be encoded, either 1D with shape ``(T,)`` or 2D
        with shape ``(T, F)`` (time steps by features).
    :type signal: numpy.ndarray
    :param num_bits: The number of bits used to encode each chunk. Must be a positive
        integer that divides the signal length ``T``.
    :type num_bits: int
    :return: The phase-encoded spike train as a ``uint8`` array of zeros and ones, with
        shape ``(T,)`` when there is a single feature and ``(T, F)`` otherwise.
    :rtype: numpy.ndarray
    :raises ValueError: If the input signal is empty, if ``num_bits`` is not positive,
        or if ``num_bits`` is not a factor of the signal length.

    """
    data = np.asarray(signal)

    # Check for invalid inputs
    if len(data) == 0:
        raise ValueError("Signal cannot be empty.")
    if num_bits < 1:
        raise ValueError(f"The phase_encoding num_bits ({num_bits}) must be a positive integer.")

    # Ensure 2D processing (T, F)
    if data.ndim == 1:
        data = data.reshape(-1, 1)
    num_steps, num_features = data.shape

    if num_steps % num_bits != 0:
        raise ValueError(
            f"The phase_encoding num_bits ({num_bits}) is not a factor of the signal length ({num_steps})."
        )

    # Work on a private copy so the caller's array is untouched. Boolean/integer data is
    # promoted to float because the in-place normalization below cannot store fractional
    # values in an integer array; floating-point inputs keep their own precision.
    if data.dtype.kind in "biu":
        work = data.astype(np.float64)
    else:
        work = data.copy()

    # Shift features that contain negative values so that their minimum becomes 0
    offset = work.min(axis=0)
    offset[offset > 0] = 0  # only shift if negative values are present
    work -= offset

    # Scale down only the features whose maximum amplitude is greater than 1
    peak = work.max(axis=0)
    np.divide(work, peak, out=work, where=peak > 1)

    # Mean of every chunk of `num_bits` consecutive samples: shape (T // num_bits, F)
    chunk_mean = work.reshape(num_steps // num_bits, num_bits, num_features).mean(axis=1)

    # Normalize the chunk means per feature; all-zero features are left as they are
    chunk_peak = chunk_mean.max(axis=0)
    np.divide(chunk_mean, chunk_peak, out=chunk_mean, where=chunk_peak > 0)

    # Quantize the phase angle (in [0, pi/2]) into 2**num_bits levels
    angles = np.arcsin(chunk_mean)
    num_levels = 2**num_bits
    bins = np.linspace(0, np.pi / 2, num_levels + 1)
    # searchsorted can return `num_levels` for an angle of exactly pi/2, so clip it back
    # into the range representable with `num_bits` bits.
    levels = np.clip(np.searchsorted(bins, angles), 0, num_levels - 1)

    # Expand every level into its binary representation, most significant bit first:
    # right-shift to bring each bit position to the least significant bit, then mask with 1.
    bit_shifts = np.arange(num_bits - 1, -1, -1)
    bits = ((levels[..., None] >> bit_shifts) & 1).astype(np.uint8)

    # (chunk, feature, bit) -> (chunk, bit, feature) -> (T, F): the bits of a chunk occupy
    # the time steps of that chunk.
    spikes = bits.transpose(0, 2, 1).reshape(num_steps, num_features)

    # Flatten if there is a single feature (e.g. the input was 1D)
    if num_features == 1:
        spikes = spikes.flatten()

    return spikes