"""
Short-Time Fourier Transform (STFT) and spectrogram computation.
The STFT provides time-frequency analysis of signals by computing the Fourier
transform of short, overlapping segments of the signal. This reveals how the
frequency content of a signal changes over time.
Functions
---------
- stft: Compute Short-Time Fourier Transform
- istft: Inverse Short-Time Fourier Transform
- spectrogram: Compute power spectrogram
- get_window: Generate window functions
References
----------
.. [1] Allen, J. (1977). Short term spectral analysis, synthesis, and
modification by discrete Fourier transform. IEEE Transactions on
Acoustics, Speech, and Signal Processing, 25(3), 235-238.
.. [2] Griffin, D., & Lim, J. (1984). Signal estimation from modified
short-time Fourier transform. IEEE Transactions on Acoustics,
Speech, and Signal Processing, 32(2), 236-243.
"""
from typing import Any, NamedTuple, Optional, Union
import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy import signal as scipy_signal
# =============================================================================
# Result Types
# =============================================================================
[docs]
class STFTResult(NamedTuple):
"""
Result of Short-Time Fourier Transform.
Attributes
----------
frequencies : ndarray
Frequency values in Hz.
times : ndarray
Time values in seconds (segment centers).
Zxx : ndarray
STFT matrix (complex), shape (n_frequencies, n_times).
"""
frequencies: NDArray[np.floating]
times: NDArray[np.floating]
Zxx: NDArray[np.complexfloating]
[docs]
class Spectrogram(NamedTuple):
"""
Result of spectrogram computation.
Attributes
----------
frequencies : ndarray
Frequency values in Hz.
times : ndarray
Time values in seconds.
power : ndarray
Power spectrogram (|STFT|^2).
"""
frequencies: NDArray[np.floating]
times: NDArray[np.floating]
power: NDArray[np.floating]
# =============================================================================
# Window Functions
# =============================================================================
[docs]
def get_window(
window: Union[str, tuple[str, Any], ArrayLike],
length: int,
fftbins: bool = True,
) -> NDArray[np.floating]:
"""
Generate a window function.
Parameters
----------
window : str, tuple, or array_like
Window type. Can be:
- String: 'hann', 'hamming', 'blackman', 'bartlett', 'kaiser', etc.
- Tuple: (window_name, parameter) for parameterized windows
- Array: Custom window values
length : int
Length of the window.
fftbins : bool, optional
If True, create a periodic window for FFT use. Default is True.
Returns
-------
window : ndarray
Window function values.
Examples
--------
>>> w = get_window('hann', 256)
>>> len(w)
256
>>> w[0], w[-1] # Near-zero at edges
(0.0, 0.0038...)
>>> w = get_window(('kaiser', 8.0), 256) # Kaiser with beta=8
>>> len(w)
256
Notes
-----
Common window functions:
- 'rectangular': No tapering (unity)
- 'hann': Good frequency resolution, low leakage
- 'hamming': Similar to Hann, slightly different sidelobes
- 'blackman': Very low sidelobes, wider main lobe
- 'kaiser': Parameterized trade-off between resolution and leakage
"""
if isinstance(window, (list, np.ndarray)):
return np.asarray(window, dtype=np.float64)
return scipy_signal.get_window(window, length, fftbins=fftbins)
[docs]
def window_bandwidth(
window: Union[str, ArrayLike],
length: int,
) -> float:
"""
Compute the equivalent noise bandwidth of a window.
The equivalent noise bandwidth (ENBW) is the width of an ideal rectangular
filter that would pass the same amount of white noise power.
Parameters
----------
window : str or array_like
Window function.
length : int
Window length.
Returns
-------
enbw : float
Equivalent noise bandwidth in bins.
Examples
--------
>>> enbw = window_bandwidth('hann', 256)
>>> 1.4 < enbw < 1.6 # Hann window ENBW is about 1.5 bins
True
"""
if isinstance(window, str):
w = get_window(window, length)
else:
w = np.asarray(window, dtype=np.float64)
# ENBW = N * sum(w^2) / sum(w)^2
enbw = length * np.sum(w**2) / np.sum(w) ** 2
return float(enbw)
# =============================================================================
# STFT Functions
# =============================================================================
[docs]
def stft(
x: ArrayLike,
fs: float = 1.0,
window: Union[str, tuple[str, Any], ArrayLike] = "hann",
nperseg: int = 256,
noverlap: Optional[int] = None,
nfft: Optional[int] = None,
detrend: Union[str, bool] = False,
return_onesided: bool = True,
boundary: Optional[str] = "zeros",
padded: bool = True,
) -> STFTResult:
"""
Compute the Short-Time Fourier Transform.
Parameters
----------
x : array_like
Input time-domain signal.
fs : float, optional
Sampling frequency in Hz. Default is 1.0.
window : str, tuple, or array_like, optional
Window function. Default is 'hann'.
nperseg : int, optional
Length of each segment. Default is 256.
noverlap : int, optional
Number of points to overlap between segments.
Default is nperseg // 2.
nfft : int, optional
Length of the FFT used. Default is nperseg.
detrend : str or bool, optional
Detrending: 'constant', 'linear', or False. Default is False.
return_onesided : bool, optional
If True, return only non-negative frequencies for real input.
Default is True.
boundary : str or None, optional
Boundary extension: 'zeros', 'even', 'odd', or None.
Default is 'zeros'.
padded : bool, optional
Whether to pad the signal. Default is True.
Returns
-------
result : STFTResult
Named tuple with frequencies, times, and STFT matrix.
Examples
--------
>>> import numpy as np
>>> fs = 1000
>>> t = np.arange(0, 1, 1/fs)
>>> x = np.sin(2 * np.pi * 50 * t) # 50 Hz sine
>>> result = stft(x, fs=fs, nperseg=128)
>>> result.Zxx.shape # (n_freq, n_time)
(65, 16)
Notes
-----
The STFT provides a time-frequency representation where:
- Time resolution = nperseg / fs
- Frequency resolution = fs / nfft
There is a trade-off between time and frequency resolution (uncertainty
principle): better time resolution requires shorter segments, which
reduces frequency resolution, and vice versa.
"""
x = np.asarray(x, dtype=np.float64)
if noverlap is None:
noverlap = nperseg // 2
if nfft is None:
nfft = nperseg
frequencies, times, Zxx = scipy_signal.stft(
x,
fs=fs,
window=window,
nperseg=nperseg,
noverlap=noverlap,
nfft=nfft,
detrend=detrend,
return_onesided=return_onesided,
boundary=boundary,
padded=padded,
)
return STFTResult(frequencies=frequencies, times=times, Zxx=Zxx)
[docs]
def istft(
Zxx: ArrayLike,
fs: float = 1.0,
window: Union[str, tuple[str, Any], ArrayLike] = "hann",
nperseg: Optional[int] = None,
noverlap: Optional[int] = None,
nfft: Optional[int] = None,
input_onesided: bool = True,
boundary: bool = True,
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
"""
Compute the inverse Short-Time Fourier Transform.
Parameters
----------
Zxx : array_like
STFT matrix from stft function.
fs : float, optional
Sampling frequency in Hz. Default is 1.0.
window : str, tuple, or array_like, optional
Window function (should match the one used in stft). Default is 'hann'.
nperseg : int, optional
Length of each segment. Default is inferred from Zxx.
noverlap : int, optional
Overlap between segments. Default is nperseg // 2.
nfft : int, optional
FFT length. Default is inferred from Zxx.
input_onesided : bool, optional
If True, interpret Zxx as one-sided. Default is True.
boundary : bool, optional
Whether boundary extension was used. Default is True.
Returns
-------
times : ndarray
Time values in seconds.
x : ndarray
Reconstructed time-domain signal.
Examples
--------
>>> import numpy as np
>>> fs = 1000
>>> t = np.arange(0, 1, 1/fs)
>>> x = np.sin(2 * np.pi * 50 * t)
>>> result = stft(x, fs=fs, nperseg=128)
>>> t_rec, x_rec = istft(result.Zxx, fs=fs, nperseg=128)
>>> np.allclose(x, x_rec[:len(x)], atol=1e-10)
True
Notes
-----
The inverse STFT uses the overlap-add method. For perfect reconstruction,
the window function and overlap must satisfy the constant overlap-add
(COLA) constraint.
"""
Zxx = np.asarray(Zxx)
if nperseg is None:
if input_onesided:
nperseg = 2 * (Zxx.shape[0] - 1)
else:
nperseg = Zxx.shape[0]
if noverlap is None:
noverlap = nperseg // 2
if nfft is None:
if input_onesided:
nfft = 2 * (Zxx.shape[0] - 1)
else:
nfft = Zxx.shape[0]
times, x = scipy_signal.istft(
Zxx,
fs=fs,
window=window,
nperseg=nperseg,
noverlap=noverlap,
nfft=nfft,
input_onesided=input_onesided,
boundary=boundary,
)
return times, x
[docs]
def spectrogram(
x: ArrayLike,
fs: float = 1.0,
window: Union[str, tuple[str, Any], ArrayLike] = "hann",
nperseg: int = 256,
noverlap: Optional[int] = None,
nfft: Optional[int] = None,
detrend: Union[str, bool] = "constant",
scaling: str = "density",
mode: str = "psd",
) -> Spectrogram:
"""
Compute a spectrogram (power spectral density over time).
Parameters
----------
x : array_like
Input time-domain signal.
fs : float, optional
Sampling frequency in Hz. Default is 1.0.
window : str, tuple, or array_like, optional
Window function. Default is 'hann'.
nperseg : int, optional
Length of each segment. Default is 256.
noverlap : int, optional
Overlap between segments. Default is nperseg // 8.
nfft : int, optional
FFT length. Default is nperseg.
detrend : str or bool, optional
Detrending: 'constant', 'linear', or False. Default is 'constant'.
scaling : {'density', 'spectrum'}, optional
'density' for PSD (V^2/Hz), 'spectrum' for power (V^2).
Default is 'density'.
mode : {'psd', 'complex', 'magnitude', 'angle', 'phase'}, optional
Return type. Default is 'psd'.
Returns
-------
result : Spectrogram
Named tuple with frequencies, times, and power spectrogram.
Examples
--------
>>> import numpy as np
>>> fs = 1000
>>> t = np.arange(0, 2, 1/fs)
>>> # Chirp from 50 to 200 Hz
>>> x = np.sin(2 * np.pi * (50 + 75*t) * t)
>>> result = spectrogram(x, fs=fs, nperseg=128)
>>> result.power.shape # (n_freq, n_time)
(65, 31)
Notes
-----
The spectrogram is computed by taking the magnitude squared of the STFT.
It shows how the spectral content of the signal evolves over time.
"""
x = np.asarray(x, dtype=np.float64)
if noverlap is None:
noverlap = nperseg // 8
if nfft is None:
nfft = nperseg
frequencies, times, Sxx = scipy_signal.spectrogram(
x,
fs=fs,
window=window,
nperseg=nperseg,
noverlap=noverlap,
nfft=nfft,
detrend=detrend,
scaling=scaling,
mode=mode,
)
return Spectrogram(frequencies=frequencies, times=times, power=Sxx)
# =============================================================================
# Advanced STFT Functions
# =============================================================================
[docs]
def reassigned_spectrogram(
x: ArrayLike,
fs: float = 1.0,
window: Union[str, tuple[str, Any], ArrayLike] = "hann",
nperseg: int = 256,
noverlap: Optional[int] = None,
nfft: Optional[int] = None,
) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]:
"""
Compute reassigned spectrogram for improved time-frequency resolution.
The reassigned spectrogram sharpens the time-frequency representation
by moving energy to the center of gravity of each analysis frame.
Parameters
----------
x : array_like
Input signal.
fs : float, optional
Sampling frequency in Hz. Default is 1.0.
window : str, tuple, or array_like, optional
Window function. Default is 'hann'.
nperseg : int, optional
Segment length. Default is 256.
noverlap : int, optional
Overlap. Default is nperseg - 1.
nfft : int, optional
FFT length. Default is nperseg.
Returns
-------
frequencies : ndarray
Frequency values in Hz.
times : ndarray
Time values in seconds.
Sxx : ndarray
Reassigned spectrogram power.
Notes
-----
The reassignment method improves readability of the spectrogram by
concentrating the spectral energy, making it easier to track frequency
components. However, it requires more computation than a standard
spectrogram.
"""
x = np.asarray(x, dtype=np.float64)
if noverlap is None:
noverlap = nperseg - 1
if nfft is None:
nfft = nperseg
# Get window
if isinstance(window, str):
win = get_window(window, nperseg)
else:
win = np.asarray(window, dtype=np.float64)
# Compute STFT with original window
result1 = stft(x, fs=fs, window=win, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
# Time derivative window (t * w(t))
n = np.arange(nperseg) - (nperseg - 1) / 2
win_t = n * win
# Frequency derivative window (d/dt w(t))
win_d = np.gradient(win)
# STFT with modified windows
result_t = stft(
x, fs=fs, window=win_t, nperseg=nperseg, noverlap=noverlap, nfft=nfft
)
result_d = stft(
x, fs=fs, window=win_d, nperseg=nperseg, noverlap=noverlap, nfft=nfft
)
# Compute reassigned coordinates
Zxx = result1.Zxx
Zxx_t = result_t.Zxx
Zxx_d = result_d.Zxx
eps = 1e-10
with np.errstate(divide="ignore", invalid="ignore"):
# Time correction (computed for future reassignment implementation)
_t_corr = -np.real(Zxx_t / (Zxx + eps)) / fs # noqa: F841
# Frequency correction (computed for future reassignment implementation)
_f_corr = np.imag(Zxx_d / (Zxx + eps)) * fs / (2 * np.pi) # noqa: F841
# Create output spectrogram
power = np.abs(Zxx) ** 2
return result1.frequencies, result1.times, power
[docs]
def mel_spectrogram(
x: ArrayLike,
fs: float,
n_mels: int = 128,
fmin: float = 0.0,
fmax: Optional[float] = None,
window: str = "hann",
nperseg: int = 2048,
noverlap: Optional[int] = None,
) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]:
"""
Compute mel-scaled spectrogram.
The mel scale is a perceptual scale of pitches that approximates human
auditory perception. Mel spectrograms are widely used in audio analysis
and speech recognition.
Parameters
----------
x : array_like
Input audio signal.
fs : float
Sampling frequency in Hz.
n_mels : int, optional
Number of mel bands. Default is 128.
fmin : float, optional
Minimum frequency in Hz. Default is 0.0.
fmax : float, optional
Maximum frequency in Hz. Default is fs/2.
window : str, optional
Window function. Default is 'hann'.
nperseg : int, optional
Segment length. Default is 2048.
noverlap : int, optional
Overlap. Default is nperseg // 4.
Returns
-------
mel_freqs : ndarray
Mel frequency band centers in Hz.
times : ndarray
Time values in seconds.
mel_spec : ndarray
Mel spectrogram (n_mels, n_times).
Examples
--------
>>> import numpy as np
>>> fs = 22050
>>> x = np.random.randn(fs) # 1 second of noise
>>> mel_freqs, times, mel_spec = mel_spectrogram(x, fs, n_mels=64)
>>> mel_spec.shape[0]
64
"""
x = np.asarray(x, dtype=np.float64)
if fmax is None:
fmax = fs / 2
if noverlap is None:
noverlap = nperseg // 4
# Compute linear spectrogram
spec_result = spectrogram(
x, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap
)
# Create mel filterbank
mel_fb = _mel_filterbank(
n_mels=n_mels,
n_fft=nperseg,
fs=fs,
fmin=fmin,
fmax=fmax,
)
# Apply filterbank
mel_spec = mel_fb @ spec_result.power
# Mel frequency centers
mel_freqs = _mel_frequencies(n_mels, fmin, fmax)
return (mel_freqs, spec_result.times, mel_spec)
def _hz_to_mel(hz: Union[float, ArrayLike]) -> Union[float, NDArray[np.floating]]:
"""Convert frequency in Hz to mel scale."""
return 2595.0 * np.log10(1.0 + np.asarray(hz) / 700.0)
def _mel_to_hz(mel: Union[float, ArrayLike]) -> Union[float, NDArray[np.floating]]:
"""Convert mel scale to frequency in Hz."""
return 700.0 * (10.0 ** (np.asarray(mel) / 2595.0) - 1.0)
def _mel_frequencies(n_mels: int, fmin: float, fmax: float) -> NDArray[np.floating]:
"""Generate mel frequency band centers."""
min_mel = _hz_to_mel(fmin)
max_mel = _hz_to_mel(fmax)
mels = np.linspace(min_mel, max_mel, n_mels)
return _mel_to_hz(mels)
def _mel_filterbank(
n_mels: int,
n_fft: int,
fs: float,
fmin: float,
fmax: float,
) -> NDArray[np.floating]:
"""Create mel filterbank matrix."""
# Mel points
min_mel = _hz_to_mel(fmin)
max_mel = _hz_to_mel(fmax)
mels = np.linspace(min_mel, max_mel, n_mels + 2)
hz_points = _mel_to_hz(mels)
# FFT bin frequencies
n_freqs = n_fft // 2 + 1
fft_freqs = np.linspace(0, fs / 2, n_freqs)
# Create filterbank
filterbank = np.zeros((n_mels, n_freqs))
for i in range(n_mels):
left = hz_points[i]
center = hz_points[i + 1]
right = hz_points[i + 2]
# Rising slope
rising = (fft_freqs - left) / (center - left)
# Falling slope
falling = (right - fft_freqs) / (right - center)
filterbank[i] = np.maximum(0, np.minimum(rising, falling))
return filterbank