Source code for clawdia.pipeline

"""Basic implementation of the pipeline model.

The `Pipeline` class provides a minimal example of how to use CLAWDIA as a 
classification pipeline. This implementation assumes that the dictionaries 
have already been trained and that all necessary hyperparameters and 
post-training parameters are provided.

.. warning::

    This module is still under development and has not been thoroughly tested. 
    API changes or unexpected behavior may occur in future updates.

"""


class Pipeline:
    """Basic classification pipeline built from two pre-trained dictionaries.

    Input gravitational-wave strain data is first preprocessed (denoised and
    normalised) with a denoising dictionary and then classified with a
    classification dictionary. All hyperparameters are supplied up front and
    forwarded verbatim on every call.

    Parameters
    ----------
    dico_den : object
        Trained denoising dictionary. Must expose a ``components`` array and a
        ``reconstruct_minibatch`` method.
    dico_den_params : dict
        Extra keyword arguments forwarded to ``dico_den.reconstruct_minibatch``.
        Must not contain ``normed`` (it is always passed as ``True``).
    dico_clas : object
        Trained classification dictionary. Must expose a ``D`` array and a
        ``predict`` method.
    dico_clas_params : dict
        Extra keyword arguments forwarded to ``dico_clas.predict``. Must not
        contain ``with_losses`` (it is passed explicitly).

    Raises
    ------
    ValueError
        If the atoms of the denoising dictionary are longer than those of the
        classification dictionary, i.e. if
        ``dico_den.components.shape[0] > dico_clas.D.shape[0]``.
    """

    def __init__(self, *, dico_den, dico_den_params, dico_clas, dico_clas_params):
        # Axis 0 of both dictionary arrays is treated as the atom length.
        if dico_den.components.shape[0] > dico_clas.D.shape[0]:
            raise ValueError(
                "the length of the atoms of the denoising dictionary must be shorter"
                " or equal to the length of the atoms of the classification dictionary"
            )
        # Load settings and dictionaries.
        self.dico_den = dico_den
        self.dico_den_params = dico_den_params
        self.dico_clas = dico_clas
        self.dico_clas_params = dico_clas_params

    def __call__(self, strains, with_losses=False, with_preprocessed=False):
        """Preprocess and classify ``strains``.

        Parameters
        ----------
        strains : array-like
            Input strain data, passed to the denoising dictionary transposed
            (see ``_preprocess``).
        with_losses : bool, optional
            Forwarded to ``dico_clas.predict``.
        with_preprocessed : bool, optional
            If true, also return the preprocessed strains.

        Returns
        -------
        object or tuple
            Whatever ``dico_clas.predict`` returns. When ``with_preprocessed``
            is true, a tuple whose last element is the preprocessed strains; a
            non-tuple prediction result is first wrapped in a 1-tuple.
        """
        chewed = self._preprocess(strains)
        results = self._predict(chewed, with_losses=with_losses)
        if with_preprocessed:
            # ``predict`` may return a bare array (e.g. only labels) instead of
            # a tuple. ``+=`` on an array would attempt an element-wise add
            # rather than append, so wrap it first.
            if not isinstance(results, tuple):
                results = (results,)
            results += (chewed,)
        return results

    def _preprocess(self, strains):
        """Denoise and normalise ``strains`` with the denoising dictionary.

        The input is transposed before reconstruction and the output is
        transposed back. Presumably ``strains`` is (n_signals, signal_length)
        and ``reconstruct_minibatch`` works column-wise; TODO confirm.
        """
        # Denoise + norm
        pps = self.dico_den.reconstruct_minibatch(
            strains.T, normed=True, **self.dico_den_params
        ).T
        return pps

    def _predict(self, strains, with_losses=False):
        """Classify ``strains`` with the classification dictionary.

        Returns the result of ``dico_clas.predict`` unchanged.
        """
        results = self.dico_clas.predict(
            strains, with_losses=with_losses, **self.dico_clas_params
        )
        return results