"""
Estimators and metrics for signal analysis and comparison.
This module provides a variety of functions to compute statistical and signal-processing
metrics, such as mean squared error, structural similarity index, overlaps,
signal-to-noise ratios, and others. While some functions are specifically designed for
gravitational-wave signal analysis, they can also be applied to broader signal-processing
contexts.
"""
import numpy as np
import scipy as sp
[docs]
def mse(x, y):
    """Mean Squared Error.

    Parameters
    ----------
    x, y : array_like
        Signals to compare; must be broadcast-compatible (normally the same
        shape).

    Returns
    -------
    float
        ``mean((x - y)**2)``.
    """
    diff = np.asarray(x) - np.asarray(y)
    # np.mean already normalises by the number of samples; no extra
    # division by the length is needed.
    return float(np.mean(diff**2))
[docs]
def medse(x, y):
    """Median Squared Error.

    Returns the median of the element-wise squared differences between
    ``x`` and ``y`` as a Python float.
    """
    residuals = x - y
    squared = residuals ** 2
    return float(np.median(squared))
[docs]
def ssim(x, y, *, data_range=1.0):
    """Structural Similarity Index Measure (SSIM).

    Compute the Structural Similarity Index Measure (SSIM) between two
    arrays, `x` and `y`. SSIM is a perceptual metric that quantifies the
    similarity between two signals or images, accounting for luminance,
    contrast, and structure [1]_, [2]_. A single global window is used:
    every sample of the (flattened) inputs contributes to one mean,
    variance, and covariance.

    Reference values:
        1 → Perfect similarity.
        0 → No similarity.
        -1 → Perfect anti-correlation.

    Parameters
    ----------
    x : array_like
        Input signal or image. Must be of the same shape as `y`.
    y : array_like
        Input signal or image. Must be of the same shape as `x`.
    data_range : float, optional
        Dynamic range ``L`` of the data, used for the stabilising constants
        ``C1 = (0.01 L)**2`` and ``C2 = (0.03 L)**2``. Default 1.

    Returns
    -------
    res : float
        The Structural Similarity Index Measure between the signals `x` and `y`.

    Raises
    ------
    ValueError
        If `x` and `y` do not have the same shape.

    References
    ----------
    .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
       Image quality assessment: From error visibility to structural similarity.
       IEEE Transactions on Image Processing, 13(4), 600-612.
    .. [2] https://en.wikipedia.org/wiki/Structural_similarity
    """
    x_arr = np.asarray(x, dtype=np.float64)
    y_arr = np.asarray(y, dtype=np.float64)
    if x_arr.shape != y_arr.shape:
        raise ValueError("'x' and 'y' must have the same shape")
    # Flatten so the statistics are computed over all samples; np.cov would
    # otherwise treat each row of a 2-D input as a separate variable.
    x_arr = x_arr.ravel()
    y_arr = y_arr.ravel()

    mux = x_arr.mean()
    muy = y_arr.mean()
    # Population (ddof=0) variances and covariance.
    cov_mat = np.cov(x_arr, y_arr, ddof=0)
    sx2 = cov_mat[0, 0]
    sy2 = cov_mat[1, 1]
    sxy = cov_mat[0, 1]

    # Stabilising constants from Wang et al. (K1 = 0.01, K2 = 0.03).
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    res = float(
        (2 * mux * muy + c1) * (2 * sxy + c2)
        / ((mux**2 + muy**2 + c1) * (sx2 + sy2 + c2))
    )
    return res
[docs]
def dssim(x, y):
    """Structural Dissimilarity (DSSIM), computed as ``(1 - SSIM) / 2``.

    Reference values:
        0 → Perfect correlation.
        ½ → No correlation.
        1 → Perfect anticorrelation.
    """
    similarity = ssim(x, y)
    return (1 - similarity) / 2
[docs]
def issim(x, y):
    """Inverse Structural Similarity Index Measure (the negated SSIM).

    Reference values:
        1 → Perfect anti-correlation.
        0 → No similarity.
        -1 → Perfect similarity.

    Useful as a loss function to perform minimization.
    """
    similarity = ssim(x, y)
    return -similarity
[docs]
def residual(x, y):
    """Return ``np.linalg.norm(x - y)``, the default (2-norm / Frobenius) norm of the difference."""
    difference = x - y
    return np.linalg.norm(difference)
[docs]
def softmax(x, axis=None):
    """Softmax probability distribution.

    Parameters
    ----------
    x : array_like
        Input scores (logits).
    axis : int or tuple of int, optional
        Axis along which to normalise. ``None`` (default) normalises over
        all elements.

    Returns
    -------
    ndarray
        Array with the same shape as `x` whose entries are non-negative and
        sum to 1 along `axis`.
    """
    x = np.asarray(x, dtype=float)
    # Softmax is invariant to a constant shift; subtracting the maximum keeps
    # np.exp from overflowing (e.g. exp(1000) -> inf -> inf/inf = nan).
    shift = np.max(x, axis=axis, keepdims=True, initial=-np.inf)
    # If the maximum is not finite (every entry -inf, or an empty slice),
    # do not shift, so the result matches the unshifted formula.
    shift = np.where(np.isfinite(shift), shift, 0.0)
    coefs = np.exp(x - shift)
    return coefs / coefs.sum(axis=axis, keepdims=True)
[docs]
def inner_product_weighted(x, y, *, at, psd=None, window='hann'):
    """Compute the weighted inner product (x|y) between two signals.

    The result is ``4 * df * Re(sum(X * conj(Y) / S))`` over the frequency
    band, where X and Y are the rFFTs of the windowed signals, S is the
    interpolated PSD (or 1 if `psd` is None) and ``df = 1 / (len(x) * at)``
    is the rFFT bin spacing.

    Parameters
    ----------
    x, y : array_like
        Real-valued signals to compare, of equal length.
    at : float
        Sample time step.
    psd : 2d-array, optional
        PSD to weight the overlap, will be linearly interpolated to the right
        frequencies. Its first and last frequencies also set the (inclusive)
        band limits.
        psd[0] = frequencies
        psd[1] = psd samples
    window : str | tuple, optional
        Any window spec accepted by `scipy.signal.windows.get_window`; applied
        equally to `x` and `y`.

    Returns
    -------
    float
        The weighted inner product.

    Raises
    ------
    ValueError
        If `x` and `y` differ in length or either is complex.

    Notes
    -----
    NOTE(review): no ``at**2`` factor is applied to the DFT coefficients
    here, whereas `snr` does apply one. Normalised ratios of inner products
    are unaffected, but confirm the intended absolute normalisation against
    Eq. 12 of [1].

    References
    ----------
    [1]: Eq. 12, DOI: 10.48550/arxiv.2210.06194
    """
    x = np.asarray(x)
    y = np.asarray(y)
    ns = len(x)
    if ns != len(y):
        raise ValueError("both 'x' and 'y' must be of the same length")
    if not np.isrealobj(x):
        raise ValueError("'x' cannot be complex")
    if not np.isrealobj(y):
        raise ValueError("'y' cannot be complex")

    w_array = sp.signal.windows.get_window(window, ns)
    # rFFT
    hx = np.fft.rfft(x * w_array)
    hy = np.fft.rfft(y * w_array)
    ff = np.fft.rfftfreq(ns, d=at)
    # Bin spacing of the rFFT grid. It must be fixed before the band slicing
    # below: after slicing, ff[1] is an in-band frequency, not the spacing.
    af = 1.0 / (ns * at)

    if psd is not None:
        # Lowest and highest frequency cut-off taken from the given psd
        f_min, f_max = np.asarray(psd[0])[[0, -1]]
        i_min = np.searchsorted(ff, f_min, side='left')
        i_max = np.searchsorted(ff, f_max, side='right')
        hx = hx[i_min:i_max]
        hy = hy[i_min:i_max]
        ff = ff[i_min:i_max]

    # Compute (x|y)
    if psd is None:
        inner = 4 * af * np.sum(hx * hy.conj()).real
    else:
        psd_interp = sp.interpolate.interp1d(*psd, bounds_error=True)(ff)
        inner = 4 * af * np.sum((hx * hy.conj()) / psd_interp).real
    return inner
[docs]
def overlap(x, y, *, at=1, psd=None, window=('tukey', 0.5)):
    """Compute the Overlap between two signals:

        O = (x|y) / sqrt((x|x) · (y|y))

    Reference values:
        1 → Perfect correlation.
        0 → No correlation.
        -1 → Perfect anticorrelation.

    Parameters
    ----------
    x, y: array
        Signals to compare.
    at: float
        Sample time step. Leave as `at=1` if signals in whitened space.
    psd: 2d-array, optional
        PSD to weight the overlap, will be linearly interpolated to the right
        frequencies.
        psd[0] = frequencies
        psd[1] = psd samples
    window: str | tuple, optional
        Window spec forwarded to `inner_product_weighted`.

    Returns
    -------
    float
        The overlap. A nan ratio (e.g. zero-energy input, 0/0) is reported
        as 0; infinities are clipped to large finite values by
        `np.nan_to_num`.

    References
    ----------
    [1]: Badger C. et al., 2022 (10.48550/arxiv.2210.06194)
    """
    x = np.asarray(x)
    y = np.asarray(y)

    def _inner(a, b):
        return inner_product_weighted(a, b, at=at, psd=psd, window=window)

    # Silence divide/invalid warnings for degenerate inputs; the nan is
    # mapped to 0 below.
    with np.errstate(divide='ignore', invalid='ignore'):
        cross = _inner(x, y)
        norm_x = _inner(x, x)
        norm_y = _inner(y, y)
        ratio = cross / np.sqrt(norm_x * norm_y)
    return float(np.nan_to_num(ratio))
[docs]
def doverlap(x, y, *, at=1, psd=None, window=('tukey', 0.5)):
    """Compute the Overlap pseudo-distance ``(1 - overlap) / 2``.

    Useful to use the overlap as loss function.

    Reference values:
        0 → Perfect correlation.
        ½ → No correlation.
        1 → Perfect anticorrelation.

    Parameters
    ----------
    x, y: array
        Signals to compare.
    at: float, optional
        Sample time step. Defaults to 1 (as in `overlap`); leave as `at=1` if
        signals in whitened space.
    psd: 2d-array, optional
        PSD to weight the overlap, will be linearly interpolated to the right
        frequencies.
        psd[0] = frequencies
        psd[1] = psd samples
    window: str | tuple, optional
        Window spec forwarded to `overlap`.

    Returns
    -------
    float
        The pseudo-distance, in [0, 1] when the overlap lies in [-1, 1].
    """
    return (1 - overlap(x, y, at=at, psd=psd, window=window)) / 2
[docs]
def match(x, y, *, at=1, psd=None, window=('tukey', 0.5), return_lag=False):
    """Time/phase–maximised match between two (whitened) signals.

    This computes the PSD-weighted, normalised inner product maximised over
    a cyclic time shift (lag) and over a constant phase offset. The phase
    maximisation takes the modulus of the *complex* correlation, built from
    the positive-frequency cross-spectrum alone. Taking the modulus of a
    real-valued correlation would only maximise over a sign flip (phase 0
    or pi).

    TODO: cross-check the values numerically against PyCBC's `match`.

    Parameters
    ----------
    x, y : ndarray
        Signals to compare (same length).
    at : float
        Sample time step (seconds).
    psd : 2d-array, optional
        If given, weights the frequency-domain inner product; linearly
        interpolated to FFT frequencies. psd[0]=freqs, psd[1]=PSD samples.
        Its first and last frequencies set the (inclusive) band limits.
        If None, the signal is assumed to be whitened, and therefore `at` can
        be left `at=1` since it cancels out.
    window : str | tuple, optional
        Any scipy.signal window spec; applied equally to x and y.
    return_lag : bool, optional
        If True, also return (lag_samples, lag_seconds) at which the maximum
        match is attained (cyclic correlation notion).

    Returns
    -------
    m : float
        Match in [0, 1]. Returns 0.0 when the band is empty or either
        signal has zero (or non-finite) in-band norm.
    (k, tau) : tuple[int, float], optional
        Index lag and time lag in seconds (only if return_lag=True).

    Raises
    ------
    ValueError
        If `x` and `y` differ in length or are not real-valued.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    if n != len(y):
        raise ValueError("both 'x' and 'y' must be of the same length")
    if not np.isrealobj(x) or not np.isrealobj(y):
        raise ValueError("inputs must be real-valued")

    # Window and FFT
    w = sp.signal.windows.get_window(window, n)
    Xf = np.fft.rfft(x * w)
    Yf = np.fft.rfft(y * w)
    ff = np.fft.rfftfreq(n, d=at)

    # Optional band-limit from the PSD's first/last frequencies
    if psd is not None:
        f_min, f_max = np.asarray(psd[0])[[0, -1]]
        i_min = np.searchsorted(ff, f_min, side='left')
        i_max = np.searchsorted(ff, f_max, side='right')
    else:
        i_min, i_max = 0, Xf.size
    if i_max - i_min <= 0:
        return (0.0, (0, 0.0)) if return_lag else 0.0

    # rFFT bin spacing (equal to ff[1] - ff[0] whenever n > 1)
    df = 1.0 / (n * at)

    # Weighting (flat for whitened data)
    if psd is None:
        W_band = 1.0
    else:
        psd_interp = sp.interpolate.interp1d(*psd, bounds_error=True)(ff[i_min:i_max])
        W_band = 1.0 / psd_interp
    Xb = Xf[i_min:i_max]
    Yb = Yf[i_min:i_max]

    # Norms <x|x>, <y|y> (on the band)
    nx = 4.0 * df * np.sum((Xb * Xb.conj() * W_band).real)
    ny = 4.0 * df * np.sum((Yb * Yb.conj() * W_band).real)
    denom = np.sqrt(nx * ny)
    if not np.isfinite(denom) or denom == 0.0:
        return (0.0, (0, 0.0)) if return_lag else 0.0

    # Normalised one-sided cross-spectrum on a full-length grid, with the
    # negative-frequency half left at zero. Its inverse FFT is the complex
    # correlation versus cyclic lag: its real part is the normalised inner
    # product at that lag (each band bin counted once, as in nx/ny), and its
    # modulus is that quantity maximised over a constant phase offset.
    spectrum = np.zeros(n, dtype=np.complex128)
    spectrum[i_min:i_max] = (4.0 * df / denom) * (Xb * Yb.conj()) * W_band
    # np.fft.ifft includes a 1/n factor; undo it.
    cc = np.fft.ifft(spectrum) * n

    k = int(np.argmax(np.abs(cc)))
    m = float(np.abs(cc[k]))
    # tiny numerical guard (Cauchy–Schwarz bounds m by 1)
    if 1.0 < m < 1.0 + 1e-10:
        m = 1.0

    if return_lag:
        # map to signed lag
        k_signed = k if k <= n // 2 else k - n
        tau = k_signed * at
        return m, (k_signed, tau)
    return m
[docs]
def imatch(x, y, *, at=1, psd=None, window=('tukey', 0.5), return_lag=False):
    """Mismatch, i.e. ``1 - match()``.

    Parameters are forwarded unchanged to `match`.

    Returns
    -------
    mm : float
        ``1 - match(x, y, ...)``.
    (k, tau) : tuple[int, float], optional
        Lag at which the match is attained, passed through unchanged from
        `match` (only if ``return_lag=True``).
    """
    result = match(x, y, at=at, psd=psd, window=window, return_lag=return_lag)
    if return_lag:
        # In this mode match() returns (m, (k, tau)); only m is inverted.
        m, lag = result
        return 1 - m, lag
    return 1 - result
[docs]
def snr(strain, *, psd, at, window=('tukey', 0.5)):
    """Signal to Noise Ratio of `strain` against a noise PSD.

    Computes ``sqrt(4 * at**2 * df * sum(|H(f)|**2 / S(f)))``. H is the rFFT
    of the windowed strain and S is the PSD, linearly interpolated to the
    rFFT frequencies. ``df = 1 / (len(strain) * at)``. The sum runs over the
    rFFT bins lying within the PSD's first and last frequencies (inclusive).

    Parameters
    ----------
    strain : array_like
        Real time-domain signal.
    psd : 2d-array
        psd[0] = frequencies, psd[1] = psd samples.
    at : float
        Sample time step.
    window : str | tuple | array_like | None, optional
        Either a `scipy.signal.windows.get_window` spec (str or tuple), or an
        explicit array of window samples broadcastable against `strain`.
        None applies no window.

    Returns
    -------
    float
        The signal-to-noise ratio (a NumPy floating scalar).
    """
    strain = np.asarray(strain)
    ns = len(strain)
    if window is None:
        taper = np.ones(ns)
    elif isinstance(window, (str, tuple)):
        # Named specs such as 'hann' must go through get_window; treating
        # them as arrays would fail on multiplication.
        taper = sp.signal.windows.get_window(window, ns)
    else:
        taper = np.asarray(window)

    # rFFT
    hh = np.fft.rfft(strain * taper)
    ff = np.fft.rfftfreq(ns, d=at)
    af = 1.0 / (ns * at)  # rFFT bin spacing (== ff[1] when ns > 1)

    # Lowest and highest frequency cut-off taken from the given psd
    # (same inclusive band convention as inner_product_weighted / match).
    f_min, f_max = np.asarray(psd[0])[[0, -1]]
    i_min = np.searchsorted(ff, f_min, side='left')
    i_max = np.searchsorted(ff, f_max, side='right')
    hh = hh[i_min:i_max]
    ff = ff[i_min:i_max]

    # SNR
    psd_interp = sp.interpolate.interp1d(*psd, bounds_error=True)(ff)
    weighted_power = np.sum(np.abs(hh)**2 / psd_interp)
    return np.sqrt(4 * at**2 * af * weighted_power)
[docs]
def find_merger(h: np.ndarray) -> int:
    """Estimate the index position of the merger in the given strain.

    The estimate is the index of the sample with the largest absolute
    amplitude (the first one on ties; a flat index for multi-dimensional
    input). Raises ValueError (from `np.argmax`) if `h` is empty.

    This could be done with a better estimation model, like a gaussian in
    the case of binary mergers. However for our current project this does not
    make much difference.
    """
    # Convert the NumPy integer to a built-in int, as the annotation promises.
    return int(np.argmax(np.abs(h)))