Source code for gwadama.fat

"""fat.py

Frequency analysis toolkit.

"""
import warnings

from gwpy.frequencyseries import FrequencySeries
from gwpy.timeseries import TimeSeries
import numpy as np
import scipy as sp



[docs] def whiten(strain: np.ndarray, *, asd: np.ndarray, sample_rate: int, flength: int, highpass: float = None, pad: int = 0, unpad: int = 0, normed: bool = True, **kwargs) -> np.ndarray: """Whiten a single strain signal. Whiten a strain using the input amplitude spectral density 'asd', and shrinking signals afterwarwds to 'l_window' to account for the edge effects introduced by the windowing. Parameters ---------- strain : NDArray Strain data points in time domain. asd : 2d-array Amplitude spectral density assumed for the 'set_strain'. Its components are: - asd[0] = frequency points - asd[1] = ASD points NOTE: It must has a linear and constant sampling frequency! sample_rate : int The thingy that makes things do correctly their thing. flength : int Length (in samples) of the time-domain FIR whitening filter. Passed in seconds (`flength/sample_rate`) to GWpy's whiten() function as the 'fduration' parameter. pad : int, optional Marging at each side of the strain to add (zero-pad) in order to avoid edge effects. The corrupted area at each side is `0.5 * fduration` in GWpy's whiten(). Will be cropped afterwards, thus no samples are added at the end of the call to this function. If given, 'unpad' will be ignored. unpad : int, optional Marging at each side of the strain to crop. Will be ignored if 'pad' is given. highpass : float, optional Highpass corner frequency (in Hz) of the FIR whitening filter. normed : bool If True, normalizes the strains to their maximum absolute amplitude. **kwargs: Extra arguments passed to gwpy.timeseries.Timeseries.whiten(). Returns ------- strain_w : NDArray Whitened strain (in time domain). 
""" if asd.ndim != 2: raise ValueError("'asd' must have 2 dimensions") if not isinstance(flength, int): raise TypeError("'flength' must be an integer") _asd = FrequencySeries(asd[1], frequencies=asd[0]) if pad > 0: strain = np.pad(strain, pad, 'constant', constant_values=0) _unpad = slice(pad, -pad) elif unpad == 0: _unpad = slice(None) else: _unpad = slice(unpad, -unpad) frame = TimeSeries(strain, sample_rate=sample_rate) strain_w = frame.whiten( asd=_asd, fduration=flength/sample_rate, # to seconds highpass=highpass, **kwargs ).value # Convert to numpy array!!! strain_w = strain_w[_unpad] if normed: strain_w /= np.max(np.abs(strain_w)) return strain_w
# def highpass_filter(signal: np.ndarray, # f_cut: int | float, # f_width: int | float, # sample_rate: int) -> np.ndarray: # """Apply a forward-backward digital highpass filter. # Apply a forward-backward digital highpass filter to 'signal' CENTERED # at frequency 'f_cut' with a transition band of 'f_width'. # It enforces the (single) filter to allow only loss of 2 dB at passband # (`f_cut + f_width/2` Hz) and a minimum filter of 20 dB at stopband # (`f_cut - f_width/2` Hz). # REFERENCES # ---------- # Order selection: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.buttord.html # Design: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html # Filter: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfiltfilt.html # """ # f_pass = f_cut + f_width/2 # f_stop = f_cut - f_width/2 # N, wn = sp.signal.buttord(wp=f_pass, ws=f_stop, gpass=2, gstop=16, fs=self.sample_rate) # sos = sp.signal.butter(N, wn, btype='highpass', fs=sample_rate, output='sos') # filtered = sp.signal.sosfiltfilt(sos, signal) # return filtered
[docs] def highpass_filter(signal: np.ndarray, *, f_cut: int | float, f_order: int | float, sample_rate: int) -> np.ndarray: """Apply a forward-backward digital highpass filter. Apply a forward-backward digital highpass filter to 'signal' at frequency 'f_cut' with an order of 'f_order'. Reference --------- Design: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.butter.html Filter: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfiltfilt.html """ sos = sp.signal.butter(f_order, f_cut, btype='highpass', fs=sample_rate, output='sos') filtered = sp.signal.sosfiltfilt(sos, signal) return filtered
def instant_frequency(signal, *, sample_rate, phase_corrections=None):
    """Computes the instantaneous frequency of a time-domain signal.

    Computes the instantaneous frequency of a time-domain signal using the
    central difference method on the unwrapped phase of its analytic signal,
    with optional phase corrections.

    If negative frequencies are detected, a warning is raised.

    Parameters
    ----------
    signal : ndarray
        The input time-domain signal. Must contain at least 3 samples.

    sample_rate : float
        The sampling rate of the signal (in Hz).

    phase_corrections : list of tuples, optional
        A list of phase corrections. Each tuple contains:
        (jump_start, jump_end, correction_factor), forwarded in that order
        to `correct_phase` together with the time axis of the signal.
        If None, no phase correction is applied.

    Returns
    -------
    inst_freq : ndarray
        The instantaneous frequency of the signal (in Hz), with the same
        length as 'signal'. If negative frequencies are detected, a
        RuntimeWarning is issued.

    Raises
    ------
    ValueError
        If 'signal' has fewer than 3 samples (the central difference needs
        one neighbour at each side).

    """
    n_samples = len(signal)
    if n_samples < 3:
        raise ValueError(
            "'signal' must have at least 3 samples to compute the "
            "central difference"
        )

    # Step 1: Compute the analytic signal using the Hilbert transform
    analytic_signal = sp.signal.hilbert(signal)
    # Extract the instantaneous phase
    inst_phase = np.unwrap(np.angle(analytic_signal))

    # Step 2: Apply multiple phase corrections if provided
    if phase_corrections is not None:
        # The time axis is the same for every correction: build it once.
        time = np.arange(n_samples) / sample_rate
        for jump_start, jump_end, correction_factor in phase_corrections:
            inst_phase = correct_phase(
                inst_phase, time, jump_start, jump_end, correction_factor
            )

    # Step 3: Compute the instantaneous frequency by differentiating the phase
    dt = 1.0 / sample_rate
    inst_phase_diff = (inst_phase[2:] - inst_phase[:-2]) / (2 * dt)

    # Convert phase difference to frequency
    inst_freq = inst_phase_diff / (2.0 * np.pi)

    # The central difference drops one sample at each end; repeat the edge
    # values so the result matches the input length.
    inst_freq = np.pad(inst_freq, (1, 1), mode='edge')

    if np.any(inst_freq < 0):
        warnings.warn("Non-physical negative frequencies detected in the array.", RuntimeWarning)

    return inst_freq
def correct_phase(phase, time, jump_start, jump_end, correction_factor=1.0):
    """Manually correct a phase jump.

    Measures the phase difference accumulated between the samples closest to
    'jump_start' and 'jump_end', and subtracts it (scaled by
    'correction_factor') from every sample from the end of the jump onwards.
    The input 'phase' is not modified.

    Parameters
    ----------
    phase : ndarray
        The unwrapped phase of the signal.

    time : ndarray
        The time array corresponding to the signal.

    jump_start : float
        The time where the phase jump starts.

    jump_end : float
        The time where the phase jump ends.

    correction_factor : float, optional
        The factor to scale the phase correction for fine-tuning.
        Default is 1.0.

    Returns
    -------
    corrected_phase : ndarray
        Copy of 'phase' with the correction applied from the sample closest
        to 'jump_end' onwards.

    """
    # Samples whose time stamps are nearest to the requested jump limits.
    i_start = np.abs(time - jump_start).argmin()
    i_end = np.abs(time - jump_end).argmin()

    result = np.array(phase, copy=True)

    # Phase accumulated across the jump, scaled for fine-tuning.
    offset = (result[i_end] - result[i_start]) * correction_factor

    # Shift everything from the end of the jump onwards.
    result[i_end:] -= offset

    return result
def snr(strain, *, psd, at, window=('tukey', 0.5)):
    """Signal to Noise Ratio.

    Computes the signal-to-noise ratio of a windowed strain against a noise
    power spectral density, restricted to the frequency range covered by
    'psd':

        SNR = sqrt(4 * at**2 * af * sum(|rfft(strain * window)|**2 / PSD))

    where 'af' is the frequency resolution of the transform.

    Parameters
    ----------
    strain : array-like
        Strain data points in time domain. Needs at least 2 samples.

    psd : 2d-array
        Power spectral density of the noise. Its components are:
        - psd[0] = frequency points
        - psd[1] = PSD points
        Its first and last frequencies are used as the lower and upper
        frequency cut-offs, and it is linearly interpolated at the
        frequencies of the transform.

    at : float
        Time step between consecutive samples of 'strain'.

    window : str | tuple | array-like, optional
        Window applied to the strain before the transform. A string or a
        tuple is passed to `scipy.signal.windows.get_window`; anything else
        is taken as the window samples themselves (same length as 'strain').
        Default is ('tukey', 0.5).

    Returns
    -------
    float
        The signal-to-noise ratio.

    Raises
    ------
    ValueError
        Raised by the interpolator if any selected frequency falls outside
        the range of 'psd' (bounds_error=True).

    """
    # rFFT
    strain = np.asarray(strain)
    ns = len(strain)
    # Window specifications (names or (name, params) tuples) are resolved by
    # SciPy; anything else is assumed to be the window samples.
    if isinstance(window, (str, tuple)):
        window = sp.signal.windows.get_window(window, ns)
    else:
        window = np.asarray(window)
    hh = np.fft.rfft(strain * window)
    ff = np.fft.rfftfreq(ns, d=at)
    af = ff[1]  # Frequency resolution (ff[0] is 0).

    # Lowest and highest frequency cut-off taken from the given psd
    f_min, f_max = psd[0][[0, -1]]
    # argmin over a boolean mask gives the index of its first False value,
    # i.e. the first frequency not below the threshold.
    i_min = np.argmin(ff < f_min)
    i_max = np.argmin(ff < f_max)
    if i_max == 0:
        # argmin also returns 0 when every frequency is below f_max; in that
        # case keep everything up to the end.
        i_max = len(hh)
    hh = hh[i_min:i_max]
    ff = ff[i_min:i_max]

    # SNR
    psd_interp = sp.interpolate.interp1d(*psd, bounds_error=True)(ff)
    sum_ = np.sum(np.abs(hh)**2 / psd_interp)
    snr_value = np.sqrt(4 * at**2 * af * sum_)

    return snr_value