# Source code for clawdia._dictionary_spams

import time
import warnings

import numpy as np
from numpy.typing import NDArray
import scipy.optimize
import spams
from tqdm import tqdm

from . import estimators
from . import lib


# Remove warning from OpenMP, present in older versions of python-spams.
# Old releases (<= 2.6.5.4, or builds that do not even expose __version__)
# emit a noisy KMP warning unless this environment variable is set before
# OpenMP initializes.  NOTE: the comparison is lexicographic on the version
# string, which is sufficient for the '2.6.5.4'-style tags it targets.
if '__version__' not in dir(spams) or spams.__version__ <= '2.6.5.4':
    import os
    os.environ['KMP_WARNINGS'] = 'FALSE'


class DictionarySpams:
    """Sparse Dictionary Learning (SDL) model for waveform denoising via SPAMS.

    Object-oriented wrapper around the SPAMS-python library [1]_: dictionary
    learning is delegated to ``spams.trainDL`` and sparse coding to
    ``spams.lasso``.  On top of those two routines, the class supports
    arbitrarily long signals, minibatch processing of large datasets, signal
    preprocessing utilities, composite denoising schemes (such as iterative
    reconstruction), and easy saving/loading of the dictionary state.

    Attributes
    ----------
    dict_init : ndarray
        Atoms of the initial dictionary.  Remains unaltered after training.
    components : ndarray
        Atoms of the current (trained) dictionary.
    model : tuple
        SPAMS' trainDL model components in the form (A, B, iter).
    d_size : int
        Number of atoms in the dictionary (dictionary size).
    a_length : int
        Length of each atom in the dictionary (patch size).
    lambda1 : float
        Regularization parameter for training the dictionary.
    batch_size : int
        Batch size used in mini-batch training.
    n_iter : int
        Number of iterations performed during training.
    t_train : float
        Total training time in seconds.
    trained : bool
        Indicates whether the dictionary has been trained.
    n_train : int
        Number of patches used during training.
    mode_traindl : int
        Training mode for SPAMS' `trainDL` function.
    modeD_traindl : int
        Dictionary mode for SPAMS' `trainDL` function.
    mode_lasso : int
        Mode for SPAMS' `lasso` function.
    identifier : str
        Optional identifier or note for distinguishing the dictionary.

    References
    ----------
    .. [1] SPAMS (for python), (http://spams-devel.gforge.inria.fr/).
           Last accessed in October 2018.
    """
[docs] def __init__(self, dict_init=None, model=None, signal_pool=None, a_length=None, d_size=None, wave_pos=None, patch_min=1, l2_normed=True, allow_allzeros=False, random_state=None, ignore_completeness=False, lambda1=None, batch_size=64, n_iter=None, n_train=None, trained=False, mode_traindl=0, modeD_traindl=0, mode_lasso=2, identifier=''): """Initialize the dictionary. There are two ways to initialize the dictionary: 1. By directly providing the initial dictionary with `dict_init`. 2. By providing a collection of signals (`signal_pool`) from which atoms are randomly extracted to form the initial dictionary. If the second option is used, `a_length` and `d_size` must be explicitly specified to define the size of the dictionary. Additional optional parameters provide more control over this process. Parameters ---------- dict_init : ndarray of shape (d_size, a_length), optional Atoms of the initial dictionary. If `None`, `signal_pool` must be provided. model : dict, optional SPAMS' `trainDL` model components as a dictionary with elements {A, B, iter}. Must be provided if continuing training from a previous state. signal_pool : ndarray of shape (n_signals, n_samples), optional A collection of signals from which atoms are extracted to form the initial dictionary. Ignored if `dict_init` is provided. a_length : int, optional Length of each atom in the dictionary (patch size). Required if `signal_pool` is provided. d_size : int, optional Number of atoms in the dictionary. Required if `signal_pool` is provided. wave_pos : array-like of shape (n_signals, 2), optional Positions of waveforms within `signal_pool` to extract atoms from. If `None`, the entire array is used. patch_min : int, default=1 Minimum number of samples for each extracted patch. Ignored if `wave_pos` is `None`. l2_normed : bool, default=True If `True`, normalize extracted atoms to their L2 norm. 
allow_allzeros : bool, default=False By default, random atoms with all zeros are excluded from the initial dictionary. If `allow_allzeros=True`, they are allowed. random_state : int, optional Seed for random sampling from `signal_pool`. ignore_completeness : bool, optional, default=False If `False`, the dictionary must be overcomplete (`d_size > a_length`). lambda1 : float, optional Regularization parameter for training. batch_size : int, default=64 Batch size used during training. n_iter : int, optional Total number of iterations for training. If `None`, this must be set when calling the `train` method. n_train : int, optional Number of patches used for training. Informational only. trained : bool, default=False Indicates whether the dictionary is already trained. mode_traindl : int, default=0 Training mode for SPAMS' `trainDL` function. See SPAMS documentation. modeD_traindl : int, default=0 Dictionary mode for SPAMS' `trainDL` function. See SPAMS documentation. mode_lasso : int, default=2 Mode for SPAMS' `lasso` function. See SPAMS documentation. identifier : str, optional A note or label for identifying the dictionary. Notes ----- This method initializes the dictionary but does not train it. Use the `train` method for training. """ self.model = model self.dict_init = dict_init self.components = dict_init self.a_length = a_length self.d_size = d_size self.lambda1 = lambda1 self.batch_size = batch_size self.n_iter = n_iter self.t_train = -n_iter if n_iter is not None and n_iter < 0 else None self.trained = trained self.n_train = n_train self.mode_traindl = mode_traindl self.modeD_traindl = modeD_traindl self.mode_lasso = mode_lasso self.identifier = identifier self._check_initial_parameters(signal_pool, ignore_completeness) # Explicit initial dictionary (trained or not). if self.dict_init is not None: self.d_size, self.a_length = self.dict_init.shape # Get the initial atoms from a set of signals. 
else: self.dict_init = lib.extract_patches( signal_pool, patch_size=self.a_length, limits=wave_pos, n_patches=self.d_size, l2_normed=l2_normed, allow_allzeros=allow_allzeros, patch_min=patch_min, random_state=random_state ) self.components = self.dict_init
[docs] def train(self, patches, lambda1=None, n_iter=None, warm_start=False, verbose=False, threads=-1, **kwargs): """Train the dictionary. Train the dictionary with the given patches. This also allows a warm start using the previous components as initial dictionary, but only if the lambda1 parameter is the same. It can be thought of as adding more iterations to the training. Hence, providing different patches is discouraged and untested. Parameters ---------- patches : 2d-array(signals, samples) Training patches. lambda1 : float, optional Regularization parameter of the learning algorithm. It is not needed if already specified at initialization. n_iter : int, optional Total number of iterations to perform. If a negative number is provided it will perform the computation during the corresponding number of seconds. For instance `n_iter = -5` trains the dictionary during 5 seconds. warm_start : bool If True, use the previous components as initial dictionary. It can be thought of as adding more iterations to the training. Providing different patches is discouraged and untested. verbose : bool, optional If True print the iterations (might not be shown in real time). threads : int, optional Number of threads to use during training, see [1]. **kwargs Passed directly to 'spams.trainDL', see [1]. See Also -------- clawdia.lib.extract_patches : Useful for generating the training `patches`. 
""" if self.trained: if not warm_start: raise ValueError("the dictionary has already been trained") if lambda1 is not None and lambda1 != self.lambda1: raise ValueError("the 'lambda1' parameter must be the same " "as the one used at the previous training") if patches.shape[1] != self.a_length: raise ValueError("the length of 'patches' must be the same as the" " atoms of the dictionary") if n_iter is None: if self.n_iter is None: raise TypeError("'n_iter' not specified") else: n_iter = self.n_iter if lambda1 is None: if self.lambda1 is None: raise TypeError("'lambda1' not specified") lambda1 = self.lambda1 tic = time.time() components, model = spams.trainDL( patches.T, # SPAMS works with Fortran order. D=self.components.T, # model=self.model, batchsize=self.batch_size, K=self.d_size, # In SPAMS argo, the dictionary size is the number of atoms. lambda1=lambda1, iter=n_iter, mode=self.mode_traindl, modeD=self.modeD_traindl, verbose=verbose, numThreads=threads, return_model=True, **kwargs ) self.components = components.T self.model = model tac = time.time() if warm_start: if n_iter < 0: self.n_iter += model['iter'] self.t_train += -n_iter else: self.n_iter += n_iter self.t_train += tac - tic else: self.trained = True self.lambda1 = lambda1 self.n_train = patches.shape[0] if n_iter < 0: self.n_iter = model['iter'] self.t_train = -n_iter else: self.n_iter = n_iter self.t_train = tac - tic
def _reconstruct_single(self, signal, sc_lambda, step=1, **kwargs_lasso): # TODO: Add kwarg option to disable the patch normalization. # This might be usefull for tasks such detection or when a heavy # discrimination is needed. patches, norms = lib.extract_patches( signal, patch_size=self.a_length, step=step, l2_normed=True, return_norm_coefs=True ) code = spams.lasso( patches.T, # SPAMS works with Fortran order. D=self.components.T, # lambda1=sc_lambda, mode=self.mode_lasso, **kwargs_lasso ) patches = ((self.components.T @ code) * norms).T signal_rec = lib.reconstruct_from_patches_1d(patches, step) return signal_rec, code
[docs] def reconstruct(self, signal, sc_lambda, step=1, normed=True, with_code=False, **kwargs): """Reconstruct a signal as a sparse combination of dictionary atoms. Parameters ---------- signal : ndarray Sample to be reconstructed. sc_lambda : float Regularization parameter of the sparse coding transformation. step : int, 1 by default Sample interval between each patch extracted from signal. Determines the number of patches to be extracted. 1 by default. normed : boolean, True by default Normalize the result to the maximum absolute value. with_code : boolean, False by default. If True, also returns the coefficients array. **kwargs Passed directly to the external learning function. Returns ------- signal_rec : array Reconstructed signal. code : array(a_length, d_size), optional Transformed data, encoded as a sparse combination of atoms. Returned when 'with_code' is True. """ if not isinstance(signal, np.ndarray): raise TypeError("'signal' must be a numpy array") signal_rec, code = self._reconstruct_single(signal, sc_lambda, step, **kwargs) if normed and signal_rec.any(): norm = np.max(np.abs(signal_rec)) signal_rec /= norm code /= norm return (signal_rec, code) if with_code else signal_rec
def _reconstruct_batch(self, strains, *, sc_lambda, step=1, normed_windows=True, **kwargs): ns = strains.shape[0] patches, norms = lib.extract_patches( strains, patch_size=self.a_length, step=step, l2_normed=normed_windows, return_norm_coefs=True ) codes = spams.lasso( patches.T, # SPAMS works with Fortran order. D=self.components.T, # lambda1=sc_lambda, mode=self.mode_lasso, **kwargs ) patches = ((self.components.T @ codes) * norms).T lp = patches.shape[1] np_ = patches.shape[0] // ns # Number of patches per strain patches = patches.reshape(ns, np_, lp, order='C') reconstructions = np.empty_like(strains) for i in range(ns): reconstructions[i] = lib.reconstruct_from_patches_1d(patches[i], step) return reconstructions
[docs] def reconstruct_batch(self, signals, sc_lambda, out=None, step=1, normed=True, verbose=True, **kwargs): """TODO Reconstruct multiple signals, each one as a sparse combination of dictionary atoms. WARNING: Only viable for small 'signals' set, it is really memory expensive (all patches are stored in a single array in memory). WARNING: 'out' deprecated, left for backwards compatibility but will be ignored if given. """ out = self._reconstruct_batch(signals, sc_lambda=sc_lambda, step=step, **kwargs) if normed and out.any(): with np.errstate(divide='ignore', invalid='ignore'): out /= np.max(np.abs(out), axis=1, keepdims=True) np.nan_to_num(out, copy=False) return out
[docs] def reconstruct_minibatch(self, signals, *, sc_lambda, step=1, batchsize=4, normed=True, normed_windows=True, verbose=True, **kwargs): """TODO Reconstruct multiple signals, each one as a sparse combination of dictionary atoms. Minibatch version. """ n_signals = signals.shape[0] n_minibatch = n_signals // batchsize out = np.empty_like(signals) loop = range(n_minibatch) if verbose: loop = tqdm(loop) for ibatch in loop: i0 = ibatch * batchsize i1 = i0 + batchsize minibatch = signals[i0:i1] out[i0:i1] = self._reconstruct_batch( minibatch, sc_lambda=sc_lambda, step=step, normed_windows=normed_windows, **kwargs ) if n_minibatch == 0: # In case there was no point in using a minibatch: i1 = 0 # If 'n_signals' was not divisible by 'batchsize' reconstruct the # remaining signals: if i1 < n_signals: i0 = i1 minibatch = signals[i0:] out[i0:] = self._reconstruct_batch( minibatch, sc_lambda=sc_lambda, step=step, **kwargs ) if normed and out.any(): with np.errstate(divide='ignore', invalid='ignore'): out /= np.max(np.abs(out), axis=1, keepdims=True) np.nan_to_num(out, copy=False) return out
[docs] def reconstruct_margin_constrained( self, signal: NDArray, *, margin: int|tuple|list|NDArray, lambda_lims: tuple|list, step: int = 1, normed=True, full_output=False, kwargs_bisect={}, kwargs_lasso={} ) -> tuple[NDArray, NDArray, NDArray] | NDArray: """TODO""" if isinstance(margin, int): margin = signal[:margin] elif isinstance(margin, (tuple, list)): if len(margin) != 2: raise ValueError( "'margin': when given 'tuple' or 'list', it can contain " "only 2 integers." ) margin = signal[slice(*margin)] elif isinstance(margin, np.ndarray): pass else: raise TypeError( "'margin': expected types 'int', 'tuple' or 'NDArray'; " f"got '{type(margin).__name__}'." ) # Function to be bisected. def fun(sc_lambda): rec, _ = self._reconstruct_single(margin, sc_lambda, step, **kwargs_lasso) return np.sum(np.abs(rec)) try: with warnings.catch_warnings(): # Ignore specific warning from extract_patches since here we do # not care about reconstructing the entire strain (margin). warnings.filterwarnings("ignore", message="'signals' cannot be fully divided into patches.*") result = lib.semibool_bisect(fun, *lambda_lims, **kwargs_bisect) except lib.BoundaryError: rec = np.zeros_like(signal) code = None result = {'x': np.min(lambda_lims), 'f': 0., 'converged': False, 'niters': 0, 'funcalls': 2} else: rec, code = self._reconstruct_single(signal, result['x'], step, **kwargs_lasso) if normed and rec.any(): norm = np.max(np.abs(rec)) rec /= norm code /= norm return (rec, code, result) if full_output else rec
[docs] def reconstruct_iterative(self, signals, sc_lambda=0.01, step=1, batchsize=64, max_iter=100, threshold=0.001, normed=True, full_output=False, verbose=True, kwargs_lasso={}): """Reconstruct multiple signals using iterative residual subtraction. This method reconstructs each signal by iteratively updating and accumulating reconstructions. In the first iteration, the original input signal is reconstructed and then subtracted from itself to obtain the initial residual. In each subsequent iteration, a new reconstruction is generated from the current residual and subtracted from it, producing an updated residual for the next iteration, while also being added to the cumulative reconstruction. The process repeats until the Euclidean norm of the difference between consecutive residuals falls below a specified threshold, which sets the convergence criterion. NOTE: In contrast with the usual procedure, the windows into which each signal is split are not normalized. This is needed to enhance the dictionary discrimination. Otherwise, the residuals are amplified at each iteration, the algorithm takes longer to converge, and some ad-hoc tests showed it also messes up with the resulting shape. Parameters ---------- signals : ndarray Input signals to be reconstructed, with each signal along the first dimension. sc_lambda : float, optional Sparsity control parameter for reconstruction. step : int, optional Step size for the reconstruction. batchsize : int, optional Number of signals processed in each minibatch. max_iter : int, optional Maximum number of iterations before stopping. threshold : float, optional Convergence threshold based on the relative change in residuals. normed : bool, optional If True, the reconstructed signals are normalized after convergence. full_output : bool, optional If True, returns additional output values (residuals and iteration counts). verbose : bool, optional If True, prints progress information at each iteration. 
kwargs_lasso : dict, optional Additional arguments for the Lasso reconstruction method. Returns ------- ndarray or tuple The final reconstructed signals. If `full_output` is True, also returns the residuals and the number of iterations per signal. """ n_signals = signals.shape[0] # First iteration outside: if verbose: print(f"\nIteration 0") print(f"Signals remaining: {n_signals}") step_reconstructions = self.reconstruct_minibatch( signals, sc_lambda=sc_lambda, step=step, batchsize=batchsize, normed=False, # Normalization is (optionally) applied at the END. normed_windows=False, # See NOTE in the docstring. verbose=verbose, **kwargs_lasso ) final_reconstructions = step_reconstructions.copy() residuals = signals - step_reconstructions # Stop conditions iters = np.ones(n_signals, dtype=int) finished = ~step_reconstructions.any(axis=1) # In case any reconstructions are 0 already. residuals_old = residuals.copy() while not np.all(finished) and iters.max() < max_iter: if verbose: print(f"\nIteration {iters.max():3d}") print(f"Signals remaining: {(~finished).sum():^13d}") step_reconstructions = self.reconstruct_minibatch( residuals[~finished], sc_lambda=sc_lambda, step=step, batchsize=batchsize, normed=False, # Normalization is (optionally) applied at the END. normed_windows=False, # See NOTE in the docstring. 
verbose=verbose, **kwargs_lasso ) final_reconstructions[~finished] += step_reconstructions residuals[~finished] -= step_reconstructions # Stop conditions iters[~finished] += 1 residual_decrease = np.linalg.norm(residuals[~finished] - residuals_old[~finished], axis=1) finished[~finished] = residual_decrease < threshold residuals_old = residuals.copy() if verbose: print( "CURRENT RESIDUAL DECREASE:\n" f"Max: {residual_decrease.max()}\n" f"Mean: {residual_decrease.mean()}\n" f"Min: {residual_decrease.min()}\n" ) if not np.all(finished): print("WARNING: reached max_iter before finishing all the reconstructions") if normed: with np.errstate(divide='ignore', invalid='ignore'): final_reconstructions /= np.max(np.abs(final_reconstructions), axis=1, keepdims=True) np.nan_to_num(final_reconstructions, copy=False) return (final_reconstructions, residuals, iters) if full_output else final_reconstructions
[docs] def reconstruct_loss_optimised( self, strain, *, reference, step=1, limits=None, loss_func='match', normed=True, kwargs_minimize={ 'method': 'bounded', 'bounds': (-2,1), 'options': {'maxiter': 100, 'xatol': 0.04} }, kwargs_lasso={}, verbose=False ): """Find the best reconstruction of a signal w.r.t. a reference. Find the lambda which produces a reconstruction of the input 'strain' closest to the given 'reference', according to a chosen loss function: Match, Overlap, SSIM, or a custom one. The minimisation is performed by SciPy's 'minimize_scalar', with options specified through `kwargs_minimize`. PARAMETERS ---------- strain: ndarray Input strain to be reconstructed (and optimized). reference: ndarray Reference strain which to compare the reconstruction to. step: int, optional Separation in samples between each window into which the input strain is split up to be reconstructed by the dictionary. Defaults to 1. limits: array-like, optional Indices of limits to where compute the loss between the reconstruction and the reference strain. loss_func: str | callable, optional If 'str', can be 'match' (default), 'overlap' or 'ssim'. In both cases, their pseudo-distance is used. Refer to their documentation in 'clawdia.estimators' for more details. If 'callable', it must be a symmetric function of 2 arguments, over whose the 'reference' signal and the denoised signal will be passed. It must return a distance-like score between 0 (best) and 1 (worst) to guide the minimisation algorithm. normed: bool, optional If True, returns the signal normed to its maximum absolute amplitude. kwargs_minimize: dict Passed to SciPy's `minimize_scalar(**kwargs_minimize)`. Bracket or boundary values must be passed as `np.log10(bounds)`. kwargs_lasso: dict, optional Passed to Python-Spams' `lasso(**kwargs_lasso)`. verbose: bool, optional Set the maximum verbosity (`'disp': 3`) to SciPy's `minimize_scalar` and print info about the minimization results. False by default. 
RETURNS ------- rec: ndarray Optimum reconstruction found. l_opt: float Optimum value for lambda. loss: float dOverlap `(1 - Overlap)/2` or DSSIM `(1 - SSIM)/2` between the optimized reconstruction and the reference. """ # Trim strain and reference if limits are specified. sl = slice(None) if limits is None else slice(*limits) reference_ = reference[sl] strain_ = strain[sl] # Set the loss function: if loss_func == 'match': lossf = lambda x: estimators.imatch(x, reference_) _worst = 1.0 _eps = 1e-2 # start a hair above worst elif loss_func == 'ssim': lossf = lambda x: estimators.dssim(x, reference_) _worst = 0.5 _eps = 1e-2 # start a hair above worst elif loss_func == 'overlap': lossf = lambda x: estimators.doverlap(x, reference_) _worst = 0.5 # dOverlap in [0, ~0.5] in practice; theoretical worst = 1.0 for full anticorrelation _eps = 1e-3 # keep this very close to 0.5 elif callable(loss_func): lossf = lambda x: loss_func(x, reference_) _worst = 1.0 # conservative default _eps = 1e-2 else: raise ValueError(f"loss function '{loss_func}' not implemented") # Extract log10-lambda search bounds if provided; else default to [0.01, 10] => [-2, 1] log_bounds = kwargs_minimize.get('bounds', (np.log10(0.01), np.log10(10.0))) if log_bounds[0] > log_bounds[1]: log_bounds = log_bounds[::-1] log_l_min, log_l_max = float(log_bounds[0]), float(log_bounds[1]) if not np.isfinite(log_l_min) or not np.isfinite(log_l_max): # Fall back log_l_min, log_l_max = -2.0, 1.0 # Build a gentle positive slope so the penalty increases with # log10(lambda). Make it span a modest delta across the search range; # this keeps it dominant only when rec==0. For dOverlap keep the rise # small; for DSSIM allow a bit more headroom. 
span = 0.10 if loss_func == 'overlap' else 0.25 slope = span / (log_l_max - log_l_min) intercept = _worst + _eps # start just above the worst value def _null_penalty(l_rec_log: float) -> float: # Affine function in log-space: at log_l_min it's ~worst+eps and # then rises linearly return intercept + slope * (l_rec_log - log_l_min) rec = None def cost_function(l_rec_log): """Function to be minimized.""" nonlocal rec l_rec = 10 ** l_rec_log # Opitimizes lambda in log. space! rec = self.reconstruct(strain_, l_rec, step=step, normed=normed, **kwargs_lasso) if rec.any(): return lossf(rec) # If reconstruction is identically zero, steer away from larger lambda: return _null_penalty(l_rec_log) if verbose: # Add maximum verbosity to `scipy.optimize.minimize_scalar`, unless # the `disp` option is present. if 'options' in kwargs_minimize: if 'disp' not in kwargs_minimize['options']: kwargs_minimize['options']['disp'] = 3 else: kwargs_minimize['options'] = {'disp': 3} result = scipy.optimize.minimize_scalar(cost_function, **kwargs_minimize) l_opt = 10 ** result['x'] loss = result['fun'] # If the section to be optimised was shorter than the whole strain, # we need to reconstruct the whole strain with the found lambda. # Otherwise, we keep the `rec` value directly from the cost function. if limits is not None: if verbose: print("Reconstructing the whole strain with the optimal lambda found.") rec = self.reconstruct(strain, l_opt, step=step, normed=normed, **kwargs_lasso) if verbose: success = result['success'] print( "Optimization results:\n" f"> Minimization success: {success}" ) if not success: print( " Reason\n" " ------\n" + result['message'] + "\n" " ------" ) print( f"> Lambda optimized: {l_opt}\n" f"> Iterations performed: {result['nit']}\n" f"> Final loss: {loss}" ) return rec, l_opt, loss
[docs] def save(self, file): """Save the current state of the DictionarySpams object to a file. This method saves all attributes of the object as a `.npz` file. If the object has not been trained, certain attributes (`lambda1`, `n_train`, and `t_train`) are removed to avoid potential issues when reloading the state. Parameters ---------- file : str or file-like object The file path or file object where the state of the object will be saved. If a string is provided, it specifies the path to the `.npz` file. If a file-like object is given, it must be writable in binary mode. """ vars_ = vars(self) to_remove = [] if not self.trained: # To avoid silent bugs in the future to_remove += ['lambda1', 'n_train', 't_train'] for attr in to_remove: vars_.pop(attr) np.savez(file, **vars_)
[docs] def copy(self): """Return a copy of the dictionary. Returns a new instance of the same dictionary with the same values and state. Returns ------- dico_copy : DictionarySpams A copy of the current dictionary. """ dico_copy = DictionarySpams( dict_init=self.components.copy(), model=self.model, lambda1=self.lambda1, batch_size=self.batch_size, identifier=self.identifier, n_iter=self.n_iter, n_train=self.t_train, trained=self.trained, mode_traindl=self.mode_traindl, modeD_traindl=self.modeD_traindl, mode_lasso=self.mode_lasso ) if self.trained: # Retain the initial components of the dictionary. dico_copy.dict_init = self.dict_init return dico_copy
[docs] def reset(self): """Reset the dictionary to its initial (untrained) state.""" self.components = self.dict_init self.trained = False self.n_train = None self.t_train = None
def _check_initial_parameters(self, signal_pool, ignore_completeness): # Explicit initial dictionary. if self.dict_init is not None: if not isinstance(self.dict_init, np.ndarray): raise TypeError( f"'{type(self.dict_init).__name__}' is not a valid 'dict_init'" ) if not self.dict_init.flags.c_contiguous: raise ValueError("'dict_init' must be a C-contiguous array") if (self.dict_init.shape[1] >= self.dict_init.shape[0] and not ignore_completeness): raise ValueError("the dictionary must be overcomplete (d_size > a_length)") # Signal pool from where to extract the initial dictionary. elif signal_pool is not None: if not isinstance(signal_pool, np.ndarray): raise TypeError( f"'{type(signal_pool).__name__}' is not a valid 'signal_pool'" ) if not signal_pool.flags.c_contiguous: raise ValueError("'signal_pool' must be a C-contiguous array") if None in (self.a_length, self.d_size): raise TypeError( f"'a_length' and 'd_size' must be explicitly provided along 'signal_pool'" ) if (self.a_length >= self.d_size) and not ignore_completeness: raise ValueError("the dictionary must be overcomplete (d_size > a_length)") else: raise ValueError("either 'dict_init' or 'signal_pool' must be provided")