"""Ad-hoc plotting functions for visualizing results.
This module includes a variety of plotting utilities designed to help visualize
and interpret results during the development and debugging of the CLAWDIA
pipeline.
While not essential for CLAWDIA's core processing, these functions are useful
for presenting and analyzing outcomes, such as confusion matrices, dictionary
atoms, and spectrograms.
"""
from colorsys import rgb_to_hls, hls_to_rgb
import itertools as it
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
[docs]
def plot_confusion(cmat, ax=None, labels=None, mode='both', vmin=None, vmax=None,
                   cmap="PaleBlues", **kwargs):
    r"""Plot a confusion matrix.

    Plot a pre-computed confusion matrix `cmat`.

    Rows must contain true values, and columns predicted values. For example,
    in a binary classification case:

    +---------+---------+---------+
    | T \ P   | Pred C1 | Pred C2 |
    +=========+=========+=========+
    | True C1 | TP      | FN      |
    +---------+---------+---------+
    | True C2 | FP      | TN      |
    +---------+---------+---------+

    Where:

    - **TP** (True Positives): Correctly predicted Class 1.
    - **FN** (False Negatives): Class 1 incorrectly predicted as Class 2.
    - **FP** (False Positives): Class 2 incorrectly predicted as Class 1.
    - **TN** (True Negatives): Correctly predicted Class 2.

    The cell colors encode the row-normalized matrix (each entry divided by
    the sum of its row), so `vmin` and `vmax` refer to fractions, not counts.
    Rows whose sum is zero are shown as 0 %.

    Parameters
    ----------
    cmat : array-like of shape (n_classes, n_classes)
        The confusion matrix to plot. Converted with `numpy.asarray`, so
        nested lists are accepted as well as arrays.
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot the matrix. If not given, a new figure and
        axes are created.
    labels : list of str, optional
        The labels for the classes. If not given, the integers from 0 to
        `n_classes-1` are used.
    mode : {'absolute', 'percent', 'both'}
        The format of the annotations: absolute numbers, percentages, or both.
        Defaults to 'both'.
    vmin : float, optional
        The minimum value of the color scale.
    vmax : float, optional
        The maximum value of the color scale.
    cmap : callable (e.g. a matplotlib Colormap) or str
        Colormap used for the cells. A string is resolved with
        `matplotlib.pyplot.get_cmap`, except for the default "PaleBlues",
        a desaturated and brightened variant of Matplotlib's "Blues".
    **kwargs
        Additional keyword arguments are passed to
        `matplotlib.pyplot.subplots`. Only used when `ax` is not given.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure containing the plot, or None if `ax` was given.

    Raises
    ------
    ValueError
        If `mode` is not one of the accepted values.
    TypeError
        If `cmap` is neither callable nor a string.
    """
    if mode not in ('absolute', 'percent', 'both'):
        raise ValueError("mode can only be 'absolute', 'percent' or 'both'")

    # Tuple indexing below requires an ndarray; plain nested lists would fail.
    cmat = np.asarray(cmat)
    n_classes = len(cmat)

    # Row-wise normalization; empty rows produce 0/0 = NaN, mapped to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        cmat_perc = np.nan_to_num(cmat / np.sum(cmat, axis=1, keepdims=True))

    usetex = plt.rcParams["text.usetex"]

    def format_cell(count, fraction):
        """Build the annotation text of one cell according to `mode`."""
        if mode == 'absolute':
            return str(count)
        text = '{:.0%}'.format(fraction)
        if mode == 'both':
            text = '{}\n{}'.format(count, text)
        if usetex:
            # A bare '%' starts a comment in LaTeX; escape it.
            text = text.replace('%', r'\,\%')
        return text

    if not callable(cmap):
        if not isinstance(cmap, str):
            raise TypeError("Colormap not recognized")
        if cmap == "PaleBlues":
            cmap = _desaturate_cmap(
                plt.cm.Blues, desaturation_factor=0.8, brightness_boost=0.2
            )
        else:
            cmap = plt.get_cmap(cmap)

    fig = None  # Stays None when the caller supplies the axes.
    if ax is None:
        fig, ax = plt.subplots(**kwargs)

    ax.imshow(cmat_perc, cmap=cmap, vmin=vmin, vmax=vmax)

    for i_true, i_pred in it.product(range(n_classes), repeat=2):
        ax.annotate(
            format_cell(cmat[i_true, i_pred], cmat_perc[i_true, i_pred]),
            xy=(i_pred, i_true),
            ha='center',
            va='center'
        )

    ax.grid(False)

    # Change the frame (axes spines) color to gray
    for spine in ax.spines.values():
        spine.set_edgecolor('gray')

    # Set X and Y labels
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # Set ticks
    ax.set_xticks(range(n_classes))
    ax.set_yticks(range(n_classes))
    ax.set_xlim([-0.5, n_classes-0.5])
    ax.set_ylim([n_classes-0.5, -0.5])  # Invert Y-axis
    if labels is not None:
        ax.set_xticklabels(labels, rotation=45)
        ax.set_yticklabels(labels, rotation=45)

    # Move X-axis ticks and label to the top
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    return fig
[docs]
def plot_dictionary(array, c=None, ylim=None, **plot_kw):
    """Plot atoms from a dictionary in a squared matrix.

    Each atom (row of `array`) is drawn as a line plot in its own panel of a
    `c` x `c` grid, filled row by row, with all axes decorations hidden.

    Parameters
    ----------
    array : 2d-array
        Dictionary matrix with shape (a, l).
    c : int, optional
        Number of atoms at each side of the squared matrix of plots; the total
        number of plotted atoms will be `c ** 2`.
        If not given, it is computed as `int(np.sqrt(a))`.
        If `c ** 2` exceeds the number of atoms, the remaining panels are
        left empty.
    ylim : optional
        Passed unchanged to `Axes.set_ylim` of every panel. With the default
        None the limits are left to matplotlib.
    **plot_kw : optional
        Passed to pyplot.subplots(). `squeeze` is always forced to False.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the grid of atoms.

    Raises
    ------
    ValueError
        If `c` (given or computed) is smaller than 1.
    """
    array = np.asarray(array)
    n_atoms = array.shape[0]

    if c is None:
        c = int(np.sqrt(n_atoms))
    if c < 1:
        raise ValueError("c must be at least 1 (got {})".format(c))

    # squeeze=False keeps `axs` two-dimensional even for a 1x1 grid, where
    # pyplot.subplots would otherwise return a bare Axes object.
    fig, axs = plt.subplots(ncols=c, nrows=c, **{**plot_kw, "squeeze": False})

    # `axs.flat` walks the grid in row-major order: panel i is axs[i//c, i%c].
    for i, ax in enumerate(axs.flat):
        if i < n_atoms:
            ax.plot(array[i], lw=1)
            ax.set_ylim(ylim)
        ax.axis('off')

    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)

    return fig
def _desaturate_cmap(cmap, *, desaturation_factor, brightness_boost):
    """
    Build a paler version of a colormap.

    The colormap is sampled at 256 evenly spaced positions in [0, 1]. Every
    sample is converted to HLS, its saturation is scaled down, its lightness
    is moved towards 1, and the result is converted back to RGB. The alpha
    channel is carried over untouched.

    Parameters
    ----------
    cmap : matplotlib.colors.Colormap
        The original colormap to modify.
    desaturation_factor : float
        Multiplier applied to the saturation (0 = grayscale, 1 = no change).
    brightness_boost : float
        Fraction of the remaining headroom up to full lightness that is added
        to the lightness (0 = no change).

    Returns
    -------
    LinearSegmentedColormap
        The modified colormap with reduced saturation and increased brightness.
    """
    def _soften(red, green, blue, alpha):
        """Return one RGBA sample with lower saturation and higher lightness."""
        hue, lightness, saturation = rgb_to_hls(red, green, blue)
        saturation *= desaturation_factor
        # Capped at 1.0 so the lightness never overflows.
        lightness = min(lightness + brightness_boost * (1 - lightness), 1.0)
        red, green, blue = hls_to_rgb(hue, lightness, saturation)
        return (red, green, blue, alpha)

    positions = np.linspace(0, 1, 256)
    softened = [_soften(*rgba) for rgba in cmap(positions)]

    return mpl.colors.LinearSegmentedColormap.from_list(
        "DesaturatedBrightenedBlues", softened
    )