"""plotting.py
Custom plotting functions
"""
import warnings
import matplotlib as mpl
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import scipy as sp
def _scale_spectrogram(
    Sxx: NDArray, spec_log: bool, spec_norm: bool
) -> np.ma.MaskedArray:
    """Return the display version of a PSD spectrogram with non-finite bins masked.

    Parameters
    ----------
    Sxx : numpy.ndarray
        Spectrogram (PSD values) as returned by ``ShortTimeFFT.spectrogram``.
    spec_log : bool
        If true, take ``log10`` of the PSD. Zero-power bins become ``-inf``
        and end up masked.
    spec_norm : bool
        If true, normalise to the peak: the peak is subtracted in log scale,
        or divided out in linear scale.

    Returns
    -------
    numpy.ma.MaskedArray
        The scaled spectrogram, with NaN/inf entries masked.

    Raises
    ------
    ValueError
        If no finite value remains after ``log10``, or if linear
        normalisation is requested but the peak PSD is not a positive,
        finite number.
    """
    if spec_log:
        with np.errstate(divide='ignore', invalid='ignore'):
            scaled = np.log10(Sxx)
        finite_mask = np.isfinite(scaled)
        if not finite_mask.any():
            raise ValueError("Spectrogram contains no finite values after log10.")
        if spec_norm:
            # Normalise by the largest *finite* bin so a stray +inf cannot
            # turn the whole image into NaN/-inf.
            scaled -= scaled[finite_mask].max()
    else:
        scaled = Sxx.astype(float)
        if spec_norm:
            peak = np.nanmax(scaled)
            # A zero or NaN peak would fill the image with NaN; fail loudly,
            # consistent with the log-branch check above.
            if not (np.isfinite(peak) and peak > 0):
                raise ValueError(
                    "Cannot normalise spectrogram: peak PSD is not a positive "
                    "finite value."
                )
            scaled = scaled / peak
    return np.ma.masked_invalid(scaled)


def _colorbar_label(spec_log: bool, spec_norm: bool) -> str:
    """Build the colourbar label matching the chosen scaling options.

    Flags are tested by truthiness, so numpy booleans or ints work too.
    """
    body = r"\mathrm{PSD}\;[\mathrm{strain}^2/\mathrm{Hz}]"
    if spec_log:
        body = r"\log_{10}\, " + body
    label = f"${body}$"
    if spec_norm:
        label = "Norm. " + label
    return label


def plot_spectrogram_with_instantaneous_features(
    strain_array,
    time_array,
    fs=2**14,
    outseg=None,
    outfreq=None,
    window=None,
    hop=32,
    mfft=None,
    vmin=None,
    vmax=None,
    spec_log=True,
    spec_norm=True,
    spec_interpol='lanczos',
    if_line_width=2
) -> tuple[Figure, tuple[Axes, Axes, Axes], NDArray]:
    """Plot the spectrogram, instantaneous frequency, and strain's waveform.

    This function generates a multi-panel plot consisting of:

    1. A spectrogram of the gravitational wave strain, obtained using the
       Short-Time Fourier Transform (STFT), showing how the frequency content
       evolves over time.
    2. The instantaneous frequency of the strain, drawn on top of the
       spectrogram.
    3. The strain in the time domain, shown above the spectrogram for direct
       comparison.

    Key features of the plot:

    - **Spectrogram**: PSD over time drawn with the ``inferno`` colour map.
      The x-axis shows time (tick labels in milliseconds) and the y-axis
      shows frequency (in Hz).
    - **Instantaneous Frequency**: A line, computed with
      ``gwadama.fat.instant_frequency``, drawn over the spectrogram.
    - **Energy Scaling**: By default the spectrogram shows ``log10`` of the
      PSD (``spec_log``), normalised to its maximum (``spec_norm``). Both
      steps are optional.
    - **Dynamic Range Control**: The colour scale can be fixed through
      ``vmin`` and ``vmax``. Values outside the range are clipped.
    - **Time-Domain Waveform**: The original strain is plotted above the
      spectrogram, sharing its time axis.
    - **Segmentation**: ``outseg`` (time) and ``outfreq`` (frequency)
      restrict the displayed ranges.
    - **Styling**: Black axes background, dashed semi-transparent grid, and
      a labelled colourbar.

    Parameters
    ----------
    strain_array : numpy.ndarray
        The time-domain strain samples of the gravitational wave signal.
    time_array : numpy.ndarray
        Time stamps in seconds for each strain sample. Must have the same
        length as ``strain_array``. Only ``time_array[0]`` (the origin) and
        ``time_array[-1]`` are used for the spectrogram axes, which assume
        uniform sampling at ``fs``.
    fs : int, optional
        Sampling frequency of the data in Hz (default is 2**14 = 16384 Hz).
    outseg : tuple, optional
        ``(start, end)`` time range in seconds for the x-axis. If ``None``,
        the full span of ``time_array`` is shown.
    outfreq : tuple, optional
        ``(start, end)`` frequency range in Hz for the y-axis. If ``None``,
        ``0`` to the Nyquist frequency ``fs/2`` is shown.
    window : numpy.ndarray, optional
        Window applied in the STFT. Default is a 128-sample Tukey window with
        shape parameter 0.5.
    hop : int, optional
        Hop size in samples between successive STFT windows (default is 32).
    mfft : int, optional
        FFT length passed to ``scipy.signal.ShortTimeFFT``. If ``None``
        (default), SciPy uses the window length.
    vmin, vmax : float, optional
        Lower/upper limits of the colour scale, expressed in the *displayed*
        units (log10 and/or normalised, depending on ``spec_log`` and
        ``spec_norm``). If ``None``, the min/max of the finite displayed
        values is used.
    spec_log : bool, optional
        If true (default), display ``log10(Sxx)``.
    spec_norm : bool, optional
        If true (default), normalise the spectrogram to its maximum. In log
        scale the maximum is subtracted; in linear scale it is divided out.
    spec_interpol : str, optional
        Interpolation used by ``imshow`` (default ``'lanczos'``). See
        ``matplotlib.pyplot.imshow`` for other options.
    if_line_width : int | float, optional
        Line width of the instantaneous frequency curve (default is 2).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the complete plot.
    axs : tuple of matplotlib.axes.Axes
        ``(ax_spectrogram, ax_colorbar, ax_waveform)``.
    Sxx : numpy.ndarray
        The raw spectrogram (PSD values) of the input strain, *before* any
        log scaling or normalisation applied for display.

    Raises
    ------
    ValueError
        If the spectrogram has no finite values after ``log10``, if linear
        normalisation is impossible (non-positive or NaN peak), or if the
        resulting colour limits are not finite with ``vmin < vmax``.

    Notes
    -----
    - The y-axis of the spectrogram is in Hz. The x-axis tick labels are
      converted from seconds to milliseconds.
    - The time-domain waveform is plotted with its axes hidden.
    - Non-physical instantaneous frequency values (negative or NaN) are set
      to NaN, so the curve shows a gap there instead of a segment bridging
      the surrounding points.
    """
    from gwadama.fat import instant_frequency

    if window is None:
        window = sp.signal.windows.tukey(128, 0.5)

    # Compute the spectrogram using the ShortTimeFFT class.
    stfft_model = sp.signal.ShortTimeFFT(
        win=window, hop=hop, fs=fs, mfft=mfft,
        fft_mode='onesided', scale_to='psd'
    )
    Sxx = stfft_model.spectrogram(strain_array)
    _Sxx = _scale_spectrogram(Sxx, spec_log, spec_norm)

    # Default colour limits come from the unmasked (finite) bins only.
    # MaskedArray.min/max ignore masked entries explicitly.
    if vmin is None:
        vmin = float(_Sxx.min())
    if vmax is None:
        vmax = float(_Sxx.max())
    if not (np.isfinite(vmin) and np.isfinite(vmax)) or vmin >= vmax:
        raise ValueError("Invalid colour limits: ensure finite vmin < vmax.")
    # Make the mapping explicit and clipped.
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)  # type: ignore

    # Time-frequency extent of the STFT, shifted to the absolute time origin.
    t_origin = time_array[0]
    t0, t1, f0, f1 = stfft_model.extent(len(strain_array))
    t0 += t_origin
    t1 += t_origin

    fig = plt.figure(figsize=(10, 6))
    # 4 rows x 2 columns:
    # - rows: top waveform, small gap, spectrogram, bottom spacer;
    # - columns: main plots, narrow colourbar.
    gs = GridSpec(
        nrows=4, ncols=2,
        width_ratios=[60, 1],
        height_ratios=[1, 0.05, 6, 0.05],
        hspace=0.05, wspace=0.02
    )
    ax3 = fig.add_subplot(gs[0, 0])              # Top waveform
    ax = fig.add_subplot(gs[2, 0], sharex=ax3)   # Spectrogram
    ax2 = fig.add_subplot(gs[2, 1])              # Colorbar

    # SPECTROGRAM (ax)
    im = ax.imshow(
        _Sxx, extent=(t0, t1, f0, f1),
        origin='lower', aspect='auto', cmap='inferno',
        interpolation=spec_interpol, norm=norm
    )
    # Black background so masked (non-finite) bins blend with the colormap.
    ax.set_facecolor('black')

    # INSTANTANEOUS FREQUENCY drawn over the spectrogram (ax).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        instant_freq = np.asarray(
            instant_frequency(strain_array, fs=fs), dtype=float
        )
        # Blank non-physical values (negative or NaN) with NaN instead of
        # dropping them. Matplotlib breaks a line at NaN points, whereas
        # dropped samples would be bridged by a straight segment.
        instant_freq = np.where(instant_freq >= 0, instant_freq, np.nan)
    length = len(strain_array)
    # One time stamp per sample, uniformly spaced at 1/fs from the origin.
    # This assumes instant_frequency returns one value per strain sample.
    instant_time = np.linspace(t_origin, t_origin + (length - 1) / fs, length)
    ax.plot(instant_time, instant_freq, color='purple', lw=if_line_width)

    # COLOURBAR (ax2)
    cbar = fig.colorbar(
        im, cax=ax2,
        boundaries=np.linspace(vmin, vmax, 256),
        ticks=np.linspace(vmin, vmax, 6),
        extend='both'
    )
    cbar.set_label(_colorbar_label(spec_log, spec_norm))

    # LABELS, LIMITS, ETC
    ax.grid(True, ls='--', alpha=.4)
    if outseg is None:
        ax.set_xlim(time_array[0], time_array[-1])
    else:
        ax.set_xlim(*outseg)
    if outfreq is None:
        ax.set_ylim(0, fs / 2)
    else:
        ax.set_ylim(*outfreq)
    # X tick labels: seconds -> milliseconds.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x*1e3:.1f}"))
    ax.set_xlabel('Time [ms]')
    ax.set_ylabel('Frequency [Hz]')

    # GW IN TIME-DOMAIN ON TOP OF THE SPECTROGRAM (ax3)
    ax3.plot(time_array, strain_array, c='black', lw=1, alpha=1)
    ax3.set_xlim(ax.get_xlim())
    strain_lo, strain_hi = np.min(strain_array), np.max(strain_array)
    # Skip degenerate ranges (constant or NaN-containing strain). Identical
    # limits trigger a warning and NaN limits raise; autoscaling handles both.
    if strain_lo < strain_hi:
        ax3.set_ylim(strain_lo, strain_hi)
    ax3.axis('off')

    fig.subplots_adjust(left=0.08, right=0.91, top=0.96, bottom=0.08)
    return fig, (ax, ax2, ax3), Sxx