Source code for gwadama.tat

"""tat.py

Time analysis toolkit.

"""


import numpy as np
import scipy as sp
import scipy.signal
from scipy.interpolate import make_interp_spline as sp_make_interp_spline



[docs] def resample(strain: np.ndarray, time: np.ndarray | int, sample_rate: int, full_output=True) -> tuple[np.ndarray, int, int]: """Resample a single strain in time domain. Resample strain's sampling rate using an interpolation in the time domain for upscalling to a constant rate, and then decimate it to the target rate. The upscaled sample rate is chosen as the minimum common multiple between the next integer value of the maximum sampling rate found in the original strain, and the target sample rate. PARAMETERS ---------- strain: 1d-array Only one strain. time: 1d-array | int | float Time points. If an Int or Float is given, it is interpreted as the former sampling rate, and assumed to be constant. sample_rate: int Target sample rate. NOTE: It cannot be fractional. full_output: bool, optional If True, also returns the new time points, the upscaled sampling rate, and the factor down. RETURNS ------- strain: 1d-array Strain at the new sampling rate. time: 1d-array, optional New time points. sr_up: int, optional Upscaled sample rate. factor_down: int, optional Factor at which the signal is decimated after the upscalling. """ if isinstance(time, np.ndarray): sr_max = 1 / np.min(np.diff(time)) elif isinstance(time, int): sr_max = time t1 = (len(strain) - 1) / sr_max time = gen_time_array(0, t1, sr_max) else: raise TypeError("'time' type not recognized") # Upsample: # sr_up = int((sr_max // sample_rate + 1) * sample_rate) # Intentionally skipping last time point to avoid extrapolation by round-off errors. time_up = np.arange(time[0], time[-1], 1/sr_up) strain = sp_make_interp_spline(time, strain, k=2)(time_up) # len(strain) = len(strain) - 1 time = time_up # Downsample (if needed): # factor_down = sr_up // sample_rate if factor_down > 1: time = time[::factor_down] strain = sp.signal.decimate(strain, factor_down, ftype='fir') elif factor_down < 1: raise RuntimeError(f"factor_down = {factor_down} < 1") return strain, time, sr_up, factor_down if full_output else strain
def gen_time_array(t0, t1, sr):
    """Generate a time array with constant sampling rate.

    Extension of numpy.arange which takes care of the case when an extra
    sample is produced due to round-off errors. When this happens, the extra
    sample is cut off.

    Parameters
    ----------
    t0, t1: float
        Initial and final times of the array: [t0, t1).

    sr: int
        Sample rate.

    Returns
    -------
    times: NDArray
        Time array. Empty when the requested interval contains no sample
        (e.g. `t1 <= t0`).

    """
    times = np.arange(t0, t1, 1/sr)
    # With a non-integer step np.arange can emit one sample at or beyond the
    # stop value; drop it. An empty array has nothing to check or trim.
    if times.size and times[-1] >= t1:
        times = times[:-1]

    return times
[docs] def pad_time_array(times: np.ndarray, pad: int | tuple) -> np.ndarray: """Extend a time array by 'pad' number of samples. Parameters ---------- times: NDArray Time array. pad: int | tuple If int, number of samples to add on both sides. If tuple, number of samples to add on each side. Returns ------- NDArray Padded time array. NOTES ----- - Computes again the entire time array. - Due to round-off errors some intermediate time values might be slightly different. """ if isinstance(pad, int): pad0, pad1 = pad, pad elif isinstance(pad, tuple): pad0, pad1 = pad else: raise TypeError("'pad' type not recognized") length = len(times) + pad0 + pad1 dt = times[1] - times[0] t0 = times[0] - pad0*dt t1 = t0 + (length-1)*dt return np.linspace(t0, t1, length)
def shrink_time_array(times: np.ndarray, unpad: int) -> np.ndarray:
    """Shrink a time array on both sides by 'unpad' number of samples.

    The time step is taken from the first two samples of `times` and assumed
    constant over the whole array.

    Parameters
    ----------
    times: NDArray
        Time array.

    unpad: int
        Number of samples to remove from each side.

    Returns
    -------
    NDArray
        Shrunk time array, with `len(times) - 2*unpad` samples.

    NOTES
    -----
    - Computes again the entire time array.
    - Due to round-off errors some intermediate time values might be slightly
      different.

    """
    step = times[1] - times[0]
    n_kept = len(times) - 2*unpad

    # Rebuild the axis from its new first sample rather than slicing.
    start = times[0] + unpad*step
    stop = start + (n_kept-1)*step

    return np.linspace(start, stop, n_kept)
def find_time_origin(times: np.ndarray) -> int:
    """Find the index position of the origin of a time array.

    The origin is taken to be the sample whose absolute value is smallest,
    i.e. the one closest to zero.

    Parameters
    ----------
    times : NDArray
        Time array.

    Returns
    -------
    _ : int
        Index position of the time origin (0).

    """
    distance_to_zero = np.abs(times)

    return np.argmin(distance_to_zero)
def find_merger(h: np.ndarray) -> int:
    """Estimate the index position of the merger in the given strain.

    The merger is approximated by the sample at which the absolute amplitude
    of the strain is largest in the time domain. This holds reasonably well
    for certain clean or high-SNR simulated CBC gravitational waves.

    :warning:
        This function may be replaced in the near future by a more formal
        estimator with a better model, such as a Gaussian fit for binary
        mergers.

    :caution:
        This is a very ad-hoc method and may not be accurate for all
        datasets, especially depending on the sensitivity of the detector.
        The peak amplitude may not sit close to the merger for lower-SNR
        signals or noisy data.

    Parameters
    ----------
    h : np.ndarray
        The gravitational wave strain data.

    Returns
    -------
    int
        The index of the estimated merger position in the strain data.

    """
    amplitude = np.abs(h)

    return np.argmax(amplitude)