Filter

labchain.base.base_clases.BaseFilter

Bases: BasePlugin

Base class for filter components in the framework.

This abstract class extends BasePlugin and provides a structure for implementing filter operations, including fit and predict methods. It serves as the foundation for all filter types in the framework, ensuring consistent behavior and interfaces for machine learning operations.

Key Features
  • Wraps fit and predict so that every call updates the model and data hashes; predict is abstract, and the default fit raises NotTrainableFilterError
  • Computes SHA-1 keys and storage paths for the model and its outputs, which caching and storage layers can use
  • Supports verbose output for debugging and monitoring
  • Implements equality and hashing methods for filter comparison
  • Supports serialization and deserialization (pickling) of filter instances
Usage

To create a new filter type, inherit from this class and implement the required methods. For example:

from typing import Optional

import numpy as np

from labchain.base.base_clases import BaseFilter
from labchain.base import XYData  # assumed import path; adjust to your labchain version


class MyCustomFilter(BaseFilter):
    def __init__(self, n_components: int = 2):
        super().__init__(n_components=n_components)
        self._model = None  # Private: SVD factors learned in fit
        self._mean = None  # Private: column means of the training data

    def fit(self, x: XYData, y: Optional[XYData] = None) -> None:
        self._print_acction("Fitting MyCustomFilter")
        data = x.value
        # Center with the training mean and keep it so predict() reuses it.
        self._mean = np.mean(data, axis=0)
        self._model = np.linalg.svd(data - self._mean, full_matrices=False)

    def predict(self, x: XYData) -> XYData:
        self._print_acction("Predicting with MyCustomFilter")
        if self._model is None:
            raise ValueError("Model not fitted yet.")
        data = x.value
        _, _, Vt = self._model
        # Project onto the first n_components principal axes.
        transformed = (data - self._mean) @ Vt.T[:, : self.n_components]
        return XYData(_value=transformed, _hash=x._hash, _path=self._m_path)
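
The projection in predict relies on a standard SVD identity: if the centered training data factors as U S Vᵀ, projecting it onto the first k right singular vectors gives the first k columns of U scaled by the singular values. The stand-alone NumPy sketch below (no labchain involved; the random data and k = 2 are arbitrary choices) checks that identity and shows why predict should subtract the training mean rather than the mean of the batch it receives.

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(50, 5))  # arbitrary training matrix
k = 2  # plays the role of n_components

# Fit: center with the training mean and factorize.
mean = train.mean(axis=0)
U, s, Vt = np.linalg.svd(train - mean, full_matrices=False)

# Projecting the training data onto the first k right singular vectors
# equals the first k left singular vectors scaled by the singular values.
projected = (train - mean) @ Vt.T[:, :k]
assert np.allclose(projected, U[:, :k] * s[:k])

# Predict: a single new sample must be centered with the *training* mean.
# Centering a one-row batch with its own mean maps it to the zero vector.
sample = rng.normal(size=(1, 5))
with_train_mean = (sample - mean) @ Vt.T[:, :k]
with_own_mean = (sample - sample.mean(axis=0)) @ Vt.T[:, :k]
print(with_train_mean.shape)  # (1, 2)
print(np.allclose(with_own_mean, 0.0))  # True
```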

Attributes:

Name Type Description
_verbose bool

Controls the verbosity of output.

_m_hash str

Hash of the current model.

_m_str str

String representation of the current model.

_m_path str

Path to the current model.

_original_fit method

Reference to the original fit method.

_original_predict method

Reference to the original predict method.

Methods:

Name Description
__init__(verbose=True, *args, **kwargs)

Initializes the filter instance, setting up attributes and method wrappers.

fit(x: XYData, y: Optional[XYData]) -> Optional[float | dict]

Fits the filter to the input data.

predict(x: XYData) -> XYData

Makes predictions using the fitted filter.

verbose(value: bool) -> None

Sets the verbosity level for output.

init() -> None

Initializes filter-specific attributes.

_get_model_key(data_hash: str) -> Tuple[str, str]

Returns a SHA-1 model hash and the model string it was computed from.

_get_data_key(model_str: str, data_hash: str) -> Tuple[str, str]

Returns a SHA-1 hash for the output data and the string it was computed from.

grid(grid: Dict[str, List[Any] | Tuple[Any, Any] | dict]) -> BaseFilter

Sets up grid search parameters.

unwrap() -> BaseFilter

Returns the base filter without any wrappers.
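
_get_model_key and _get_data_key derive SHA-1 keys from strings, so the same parameters and input hashes always produce the same model path and output hash. The sketch below mirrors the string layouts in the source (`<params>(data_hash)` for the model, `model_str.predict(data_hash)` for outputs) using only hashlib. The params string is a hypothetical stand-in for item_dump(...), whose exact output is not shown on this page, and the input hashes are made up.

```python
import hashlib
from typing import Optional, Tuple


def model_key(params_repr: str, x_hash: str, y_hash: Optional[str]) -> Tuple[str, str]:
    # Same data-hash layout as _pre_fit: "<x hash>, <y hash or empty>".
    data_hash = f'{x_hash}, {y_hash if y_hash is not None else ""}'
    model_str = f"<{params_repr}>({data_hash})"
    return hashlib.sha1(model_str.encode("utf-8")).hexdigest(), model_str


def data_key(model_str: str, x_hash: str) -> Tuple[str, str]:
    # The output hash chains the model string with the input hash.
    data_str = f"{model_str}.predict({x_hash})"
    return hashlib.sha1(data_str.encode("utf-8")).hexdigest(), data_str


# Hypothetical parameter dump and input hashes, for illustration only.
params = "MyCustomFilter({'n_components': 2})"
m_hash, m_str = model_key(params, "abc123", None)
m_path = f"MyCustomFilter/{m_hash}"  # mirrors _m_path = f"{name}/{hash}"
d_hash, d_str = data_key(m_str, "def456")

print(m_str)  # <MyCustomFilter({'n_components': 2})>(abc123, )
print(d_str)  # <MyCustomFilter({'n_components': 2})>(abc123, ).predict(def456)
print(len(m_hash), len(d_hash))  # 40 40
```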

Note

This is an abstract base class. Concrete implementations must implement predict; trainable filters should also override fit, whose default implementation raises NotTrainableFilterError.

Source code in labchain/base/base_clases.py
class BaseFilter(BasePlugin):
    """
    Base class for filter components in the framework.

    This abstract class extends BasePlugin and provides a structure for implementing
    filter operations, including fit and predict methods. It serves as the foundation
    for all filter types in the framework, ensuring consistent behavior and interfaces
    for machine learning operations.

    Key Features:
        - Implements fit and predict methods for machine learning operations
        - Provides caching mechanisms for model and data storage
        - Supports verbose output for debugging and monitoring
        - Implements equality and hashing methods for filter comparison
        - Supports serialization and deserialization of filter instances

    Usage:
        To create a new filter type, inherit from this class and implement
        the required methods. For example:
        ```python
        class MyCustomFilter(BaseFilter):
            def __init__(self, n_components: int = 2):
                super().__init__(n_components=n_components)
                self._model = None  # Private: internal state

            def fit(self, x: XYData, y: Optional[XYData] = None) -> None:
                self._print_acction("Fitting MyCustomFilter")
                # Implement fitting logic here
                data = x.value
                self._model = np.linalg.svd(data - np.mean(data, axis=0), full_matrices=False)

            def predict(self, x: XYData) -> XYData:
                self._print_acction("Predicting with MyCustomFilter")
                if self._model is None:
                    raise ValueError("Model not fitted yet.")
                # Implement prediction logic here
                data = x.value
                U, s, Vt = self._model
                transformed = np.dot(data - np.mean(data, axis=0), Vt.T[:, :self.n_components])
                return XYData(_value=transformed, _hash=x._hash, _path=self._m_path)
        ```

    Attributes:
        _verbose (bool): Controls the verbosity of output.
        _m_hash (str): Hash of the current model.
        _m_str (str): String representation of the current model.
        _m_path (str): Path to the current model.
        _original_fit (method): Reference to the original fit method.
        _original_predict (method): Reference to the original predict method.

    Methods:
        __init__(verbose=True, *args, **kwargs):
            Initializes the filter instance, setting up attributes and method wrappers.

        fit(x: XYData, y: Optional[XYData]) -> Optional[float]:
            Fits the filter to the input data.

        predict(x: XYData) -> XYData:
            Makes predictions using the fitted filter.

        verbose(value: bool) -> None:
            Sets the verbosity level for output.

        init() -> None:
            Initializes filter-specific attributes.

        _get_model_key(data_hash: str) -> Tuple[str, str]:
            Generates a unique key for the model.

        _get_data_key(model_str: str, data_hash: str) -> Tuple[str, str]:
            Generates a unique key for the data.

        grid(grid: Dict[str, List[Any] | Tuple[Any, Any]]) -> BaseFilter:
            Sets up grid search parameters.

        unwrap() -> BaseFilter:
            Returns the base filter without any wrappers.

    Note:
        This is an abstract base class. Concrete implementations should override
        the fit and predict methods to provide specific functionality.
    """

    def _print_acction(self, action: str) -> None:
        """
        Print an action message with formatting.

        This method is used for verbose output to indicate the current action being performed.

        Args:
            action (str): The action message to be printed.

        Returns:
            None
        """
        s_str = "_" * 100
        s_str += f"\n{action}...\n"
        s_str += "*" * 100

        if self._verbose:
            rprint(s_str)

    def verbose(self, value: bool) -> None:
        """
        Set the verbosity of the filter.

        Args:
            value (bool): If True, enables verbose output; if False, disables it.

        Returns:
            None
        """
        self._verbose = value

    def __init__(self, verbose=True, *args: Any, **kwargs: Any):
        """
        Initialize the BaseFilter instance.

        This method sets up attributes for storing model-related information and wraps
        the fit and predict methods with pre-processing steps.

        Args:
            verbose (bool, optional): If True, enables verbose output. Defaults to True.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.
        """
        self._verbose = verbose
        self._original_fit = self.fit
        self._original_predict = self.predict

        # Replace fit and predict methods - use __dict__ directly to avoid __setattr__
        if hasattr(self, "fit"):
            self.__dict__["fit"] = self._pre_fit_wrapp
        if hasattr(self, "predict"):
            self.__dict__["predict"] = self._pre_predict_wrapp

        super().__init__(*args, **kwargs)

        m_hash, m_str = self._get_model_key(data_hash=" , ")

        self._m_hash: str = m_hash
        self._m_str: str = m_str
        self._m_path: str = f"{self._get_model_name()}/{m_hash}"

    def __eq__(self, other: object) -> bool | NotImplementedType:
        """
        Check equality between this filter and another object.

        Two filters are considered equal if they are of the same type and have the same public attributes.

        Args:
            other (object): The object to compare with this filter.

        Returns:
            bool: True if the objects are equal, False otherwise.
        """
        if not isinstance(other, BaseFilter):
            return NotImplemented
        return (
            type(self) is type(other)
            and self._public_attributes == other._public_attributes
        )

    def __hash__(self) -> int:
        """
        Generate a hash value for this filter.

        The hash is based on the filter's type and its public attributes.

        Returns:
            int: The hash value of the filter.
        """
        return hash((type(self), frozenset(self._public_attributes.items())))

    def _pre_fit(self, x: XYData, y: Optional[XYData] = None) -> Tuple[str, str, str]:
        """
        Perform pre-processing steps before fitting the model.

        This method generates and sets the model hash, path, and string representation.

        Args:
            x (XYData): The input data.
            y (Optional[XYData]): The target data, if applicable.

        Returns:
            Tuple[str, str, str]: A tuple containing the model hash, path, and string representation.
        """
        m_hash, m_str = self._get_model_key(
            data_hash=f'{x._hash}, {y._hash if y is not None else ""}'
        )
        m_path = f"{self._get_model_name()}/{m_hash}"

        print(f"Calling prefit on {self.__class__.__name__}")

        self._m_hash = m_hash
        self._m_path = m_path
        self._m_str = m_str
        return m_hash, m_path, m_str

    def _pre_predict(self, x: XYData) -> XYData:
        """
        Perform pre-processing steps before making predictions.

        This method generates a new XYData object with updated hash and path.

        Args:
            x (XYData): The input data for prediction.

        Returns:
            XYData: A new XYData object with updated hash and path.

        Raises:
            ValueError: If the model has not been trained or loaded.
        """
        try:
            d_hash, _ = self._get_data_key(self._m_str, x._hash)

            new_x = XYData(
                _hash=d_hash,
                _value=x._value,
                _path=f"{self._get_model_name()}/{self._m_hash}",
            )

            return new_x

        except Exception:
            raise ValueError("Trainable filter model not trained or loaded")

    def _pre_fit_wrapp(
        self, x: XYData, y: Optional[XYData] = None
    ) -> Optional[float | dict]:
        """
        Wrapper method for the fit function.

        This method performs pre-processing steps before calling the original fit method.

        Args:
            x (XYData): The input data.
            y (Optional[XYData]): The target data, if applicable.

        Returns:
            Optional[float]: The result of the original fit method.
        """
        m_hash = self._m_hash
        m_path = self._m_path
        m_str = self._m_str
        try:
            self._pre_fit(x, y)
            res = self._original_fit(x, y)
        except Exception as e:
            self._m_hash = m_hash
            self._m_path = m_path
            self._m_str = m_str
            raise e
        return res

    def _pre_predict_wrapp(self, x: XYData) -> XYData:
        """
        Wrapper method for the predict function.

        This method performs pre-processing steps before calling the original predict method.

        Args:
            x (XYData): The input data for prediction.

        Returns:
            XYData: The prediction results with updated hash and path.
        """
        new_x = self._pre_predict(x)
        return XYData(
            _hash=new_x._hash,
            _path=new_x._path,
            _value=self._original_predict(x)._value,
        )

    def __getstate__(self) -> Dict[str, Any]:
        """
        Prepare the object for pickling.

        This method ensures that the original fit and predict methods are stored for serialization.

        Returns:
            Dict[str, Any]: The object's state dictionary.
        """
        state = super().__getstate__()
        # Ensure we're storing the original methods for serialization
        state["fit"] = self._original_fit
        state["predict"] = self._original_predict
        return state

    def __setstate__(self, state: Dict[str, Any]):
        """
        Restore the object from its pickled state.

        This method restores the wrapper methods after deserialization.

        Args:
            state (Dict[str, Any]): The pickled state of the object.
        """
        super().__setstate__(state)
        # Restore the wrapper methods after deserialization
        self.__dict__["fit"] = self._pre_fit_wrapp
        self.__dict__["predict"] = self._pre_predict_wrapp

    def fit(self, x: XYData, y: Optional[XYData]) -> Optional[float | dict]:
        """
        Method for fitting the filter to the data.

        This method should be overridden by subclasses to implement specific fitting logic.

        Args:
            x (XYData): The input data.
            y (Optional[XYData]): The target data, if applicable.

        Returns:
            Optional[float]: An optional float value, typically used for metrics or loss.

        Raises:
            NotTrainableFilterError: If the filter does not support fitting.
        """
        raise NotTrainableFilterError("This filter does not support fitting.")

    @abstractmethod
    def predict(self, x: XYData) -> XYData:
        """
        Abstract method for making predictions using the filter.

        This method must be implemented by subclasses to provide specific prediction logic.

        Args:
            x (XYData): The input data.

        Returns:
            XYData: The prediction results.
        """
        ...

    def _get_model_name(self) -> str:
        """
        Get the name of the model.

        Returns:
            str: The name of the model (class name).
        """
        return self.__class__.__name__

    def _get_model_key(self, data_hash: str) -> Tuple[str, str]:
        """
        Generate a unique key for the model based on its parameters and input data.

        Args:
            data_hash (str): A hash representing the input data.

        Returns:
            Tuple[str, str]: A tuple containing the model hash and a string representation.
        """
        model_str = f"<{self.item_dump(exclude=set('extra_params'))}>({data_hash})"
        model_hashcode = hashlib.sha1(model_str.encode("utf-8")).hexdigest()
        return model_hashcode, model_str

    def _get_data_key(self, model_str: str, data_hash: str) -> Tuple[str, str]:
        """
        Generate a unique key for the data based on the model and input data.

        Args:
            model_str (str): A string representation of the model.
            data_hash (str): A hash representing the input data.

        Returns:
            Tuple[str, str]: A tuple containing the data hash and a string representation.
        """
        data_str = f"{model_str}.predict({data_hash})"
        data_hashcode = hashlib.sha1(data_str.encode("utf-8")).hexdigest()
        return data_hashcode, data_str

    def grid(self, grid: Dict[str, List[Any] | Tuple[Any, Any] | dict]) -> BaseFilter:
        """
        Set up grid search parameters for the filter.

        This method allows defining a grid of hyperparameters for optimization.

        Args:
            grid (Dict[str, List[Any] | Tuple[Any, Any]]): A dictionary where keys are parameter names
                and values are lists or tuples of possible values.

        Returns:
            BaseFilter: The filter instance with grid search parameters set.
        """
        self._grid = grid
        return self

    def unwrap(self) -> BaseFilter:
        """
        Return the base filter without any wrappers.

        This method is useful when you need to access the original filter without any
        additional layers or modifications added by wrappers.

        Returns:
            BaseFilter: The unwrapped base filter.
        """
        return self

    @staticmethod
    def clear_memory():
        import gc

        gc.collect()
        try:
            import torch

            torch.cuda.empty_cache()
        except ImportError:
            pass
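
_pre_fit_wrapp snapshots _m_hash, _m_path, and _m_str before calling _pre_fit, and restores them if the subclass's fit raises, so a failed fit does not leave the filter pointing at a model that was never trained. A stand-alone sketch of that snapshot-and-restore pattern, using a hypothetical Keyed class in place of BaseFilter:

```python
from typing import Optional


class Keyed:
    """Hypothetical stand-in holding only the three key attributes."""

    def __init__(self) -> None:
        self._m_hash, self._m_path, self._m_str = "h0", "Keyed/h0", "s0"

    def _pre_fit(self, x_hash: str) -> None:
        # Stand-in for the real key computation.
        self._m_hash = f"h({x_hash})"
        self._m_path = f"Keyed/{self._m_hash}"
        self._m_str = f"s({x_hash})"

    def _fit(self, x_hash: str) -> Optional[float]:
        if x_hash == "bad":
            raise ValueError("fit failed")
        return 0.5

    def fit(self, x_hash: str) -> Optional[float]:
        saved = (self._m_hash, self._m_path, self._m_str)
        try:
            self._pre_fit(x_hash)
            return self._fit(x_hash)
        except Exception:
            # Roll the keys back so they still describe the last good model.
            self._m_hash, self._m_path, self._m_str = saved
            raise


k = Keyed()
k.fit("good")
print(k._m_path)  # Keyed/h(good)
try:
    k.fit("bad")
except ValueError:
    pass
print(k._m_path)  # Keyed/h(good)
```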

__eq__(other)

Check equality between this filter and another object.

Two filters are considered equal if they are of the same type and have the same public attributes.

Parameters:

Name Type Description Default
other object

The object to compare with this filter.

required

Returns:

Name Type Description
bool bool | NotImplementedType

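True if other is a filter of the same type with equal public attributes, False otherwise. Returns NotImplemented if other is not a BaseFilter, so Python falls back to its default comparison.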
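
Equality compares the exact type and the _public_attributes mapping, so two instances of the same filter built with the same constructor arguments compare equal, while an instance of a subclass with identical arguments does not. The sketch below reproduces that rule on a hypothetical Plugin class whose _public_attributes is filled by hand; in labchain, populating that mapping is BasePlugin's job.

```python
from typing import Any, Dict


class Plugin:
    """Hypothetical stand-in that records constructor kwargs as public attributes."""

    def __init__(self, **kwargs: Any) -> None:
        self._public_attributes: Dict[str, Any] = dict(kwargs)

    def __eq__(self, other: object):
        if not isinstance(other, Plugin):
            return NotImplemented  # let Python try the reflected comparison
        return (
            type(self) is type(other)
            and self._public_attributes == other._public_attributes
        )


class PCA(Plugin): ...


class OtherPCA(PCA): ...


print(PCA(n=2) == PCA(n=2))       # True
print(PCA(n=2) == PCA(n=3))       # False
print(PCA(n=2) == OtherPCA(n=2))  # False: the exact type must match
print(PCA(n=2) == "PCA")          # False: NotImplemented falls back to identity
```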

__getstate__()

Prepare the object for pickling.

This method ensures that the original fit and predict methods are stored for serialization.

Returns:

Type Description
Dict[str, Any]

The object's state dictionary, with fit and predict set to the original (unwrapped) methods.

__hash__()

Generate a hash value for this filter.

The hash is based on the filter's type and its public attributes.

Returns:

Name Type Description
int int

The hash value of the filter.
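
Because the hash is built from frozenset(self._public_attributes.items()), every public attribute value must itself be hashable. Equal filters then hash equally and can serve as dict keys or set members, but a filter whose constructor received a list (for example a list of column names) raises TypeError when hashed. The stand-alone sketch below uses a hypothetical Plugin class to show both cases.

```python
from typing import Any, Dict


class Plugin:
    """Hypothetical stand-in mirroring BaseFilter's __eq__/__hash__ rule."""

    def __init__(self, **kwargs: Any) -> None:
        self._public_attributes: Dict[str, Any] = dict(kwargs)

    def __eq__(self, other: object):
        if not isinstance(other, Plugin):
            return NotImplemented
        return (
            type(self) is type(other)
            and self._public_attributes == other._public_attributes
        )

    def __hash__(self) -> int:
        return hash((type(self), frozenset(self._public_attributes.items())))


# Equal filters collapse to a single set member.
unique = {Plugin(n=2), Plugin(n=2), Plugin(n=3)}
print(len(unique))  # 2

# A list-valued attribute makes the (key, value) pair unhashable.
try:
    hash(Plugin(columns=["a", "b"]))
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: unhashable type: 'list'
```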

__init__(verbose=True, *args, **kwargs)

Initialize the BaseFilter instance.

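This method stores the original fit and predict methods, replaces them on the instance with wrappers that update the model and data hashes, passes the remaining arguments to BasePlugin.__init__, and computes an initial model hash, string, and path.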

Parameters:

Name Type Description Default
verbose bool

If True, enables verbose output. Defaults to True.

True
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
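
__init__ saves the bound fit and predict methods and then writes the wrappers into the instance __dict__. Instance attributes shadow plain class functions during attribute lookup, so obj.predict(...) reaches the wrapper, while the saved reference still runs the subclass's own code. Writing to __dict__ directly also bypasses any custom __setattr__, as the comment in the source notes. A minimal sketch of the mechanism with hypothetical Wrapped and Tripler classes and no labchain dependency:

```python
from typing import List


class Wrapped:
    def __init__(self) -> None:
        self.log: List[str] = []
        self._original_predict = self.predict  # resolves to the class method
        # An instance attribute shadows the class function on lookup.
        self.__dict__["predict"] = self._predict_wrapper

    def predict(self, x: int) -> int:
        self.log.append("predict")
        return x * 2

    def _predict_wrapper(self, x: int) -> int:
        self.log.append("pre")  # pre-processing step (hashing in BaseFilter)
        return self._original_predict(x)


class Tripler(Wrapped):
    # A subclass override is what _original_predict captures.
    def predict(self, x: int) -> int:
        self.log.append("tripler")
        return x * 3


w = Wrapped()
print(w.predict(3))  # 6
print(w.log)         # ['pre', 'predict']
t = Tripler()
print(t.predict(3))  # 9
print(t.log)         # ['pre', 'tripler']
```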

__setstate__(state)

Restore the object from its pickled state.

This method restores the wrapper methods after deserialization.

Parameters:

Name Type Description Default
state Dict[str, Any]

The pickled state of the object.

required
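
__getstate__ and __setstate__ exist because the wrappers live in the instance __dict__: after loading, they must be reinstalled so that a restored filter still hashes its inputs. labchain stores the original methods in the state and reinstalls the wrappers in __setstate__. The stand-alone sketch below uses a simpler variant (a hypothetical Wrapped class) that drops the wrapper from the state and rebuilds it on load, then checks with pickle that the wrapper survives a round trip.

```python
import pickle
from typing import Any, Dict


class Wrapped:
    def __init__(self) -> None:
        self.calls = 0
        self.__dict__["predict"] = self._predict_wrapper

    def predict(self, x: int) -> int:
        return x + 1

    def _predict_wrapper(self, x: int) -> int:
        self.calls += 1  # pre-processing hook
        return type(self).predict(self, x)  # call the class method directly

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("predict", None)  # do not pickle the bound wrapper
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        # Reinstall the wrapper, as BaseFilter.__setstate__ does.
        self.__dict__["predict"] = self._predict_wrapper


w = Wrapped()
w.predict(1)
clone = pickle.loads(pickle.dumps(w))
print(clone.predict(1))  # 2
print(clone.calls)       # 2: the wrapper ran before and after pickling
```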

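clear_memory() staticmethod

Runs gc.collect() and, if PyTorch is installed, calls torch.cuda.empty_cache() to release cached GPU memory. If torch cannot be imported, the ImportError is silently ignored.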

fit(x, y)

Method for fitting the filter to the data.

The default implementation raises NotTrainableFilterError. Trainable subclasses override it with their fitting logic; non-trainable filters can leave it as-is.

Parameters:

Name Type Description Default
x XYData

The input data.

required
y Optional[XYData]

The target data, if applicable.

required

Returns:

Type Description
Optional[float | dict]

An optional value returned by the subclass's fit, typically a float such as a loss or metric, or a dict.

Raises:

Type Description
NotTrainableFilterError

If the filter does not support fitting.
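
Because the base fit raises NotTrainableFilterError, a stateless filter only needs to implement predict, and calling code can tell trainable and non-trainable filters apart by catching that exception. The sketch below uses hypothetical stand-ins (a locally defined NotTrainableFilterError, Scale, Center, and a fit_if_trainable helper); it illustrates the contract and is not labchain's own pipeline logic.

```python
from typing import List, Optional


class NotTrainableFilterError(Exception):
    """Local stand-in for labchain's exception of the same name."""


class Filter:
    def fit(self, x: List[float], y: Optional[List[float]] = None) -> Optional[float]:
        raise NotTrainableFilterError("This filter does not support fitting.")

    def predict(self, x: List[float]) -> List[float]:
        raise NotImplementedError


class Scale(Filter):
    """Stateless: only predict is implemented."""

    def predict(self, x: List[float]) -> List[float]:
        return [v * 10 for v in x]


class Center(Filter):
    """Trainable: learns the mean in fit."""

    def __init__(self) -> None:
        self.mean = 0.0

    def fit(self, x: List[float], y: Optional[List[float]] = None) -> Optional[float]:
        self.mean = sum(x) / len(x)
        return None

    def predict(self, x: List[float]) -> List[float]:
        return [v - self.mean for v in x]


def fit_if_trainable(f: Filter, x: List[float]) -> bool:
    """Hypothetical helper: fit when supported, skip otherwise."""
    try:
        f.fit(x)
        return True
    except NotTrainableFilterError:
        return False


data = [1.0, 2.0, 3.0]
print(fit_if_trainable(Scale(), data))  # False
c = Center()
print(fit_if_trainable(c, data))        # True
print(c.predict(data))                  # [-1.0, 0.0, 1.0]
```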

grid(grid)

Set up grid search parameters for the filter.

This method allows defining a grid of hyperparameters for optimization.

Parameters:

Name Type Description Default
grid Dict[str, List[Any] | Tuple[Any, Any] | dict]

A dictionary mapping parameter names to candidate values given as a list, a tuple, or a dict. The dictionary is stored unchanged in self._grid.

required

Returns:

Name Type Description
BaseFilter BaseFilter

The filter instance with grid search parameters set.
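
grid only stores the dictionary in self._grid and returns the filter, so it can be chained onto the constructor call; how the grid is searched is left to whatever optimizer reads _grid. As an illustration, assuming list values mean "try each candidate" (tuple and dict values may have other meanings in labchain's optimizers and are not handled here), a hypothetical consumer could enumerate combinations with itertools.product. GridFilter and list_candidates are made-up names.

```python
import itertools
from typing import Any, Dict, Iterator, List


class GridFilter:
    """Hypothetical stand-in with the same grid() contract as BaseFilter."""

    def __init__(self, **params: Any) -> None:
        self.params = params
        self._grid: Dict[str, Any] = {}

    def grid(self, grid: Dict[str, Any]) -> "GridFilter":
        self._grid = grid
        return self  # returning self allows chaining


def list_candidates(grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one parameter dict per combination of list-valued entries."""
    names = list(grid)
    for values in itertools.product(*(grid[n] for n in names)):
        yield dict(zip(names, values))


f = GridFilter(n_components=2).grid({"n_components": [2, 3], "whiten": [True, False]})
for params in list_candidates(f._grid):
    print(params)
```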

predict(x) abstractmethod

Abstract method for making predictions using the filter.

This method must be implemented by subclasses to provide specific prediction logic.

Parameters:

Name Type Description Default
x XYData

The input data.

required

Returns:

Name Type Description
XYData XYData

The prediction results.
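
Because predict is wrapped, the XYData a subclass returns is not passed through as-is: _pre_predict_wrapp keeps only its _value and attaches the hash and path computed by _pre_predict (a SHA-1 of the model string plus the input hash, and ClassName/model_hash). A subclass therefore does not need to compute these fields itself. A stand-alone sketch, with a hypothetical XY dataclass standing in for XYData and made-up model hash and string:

```python
import hashlib
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class XY:
    """Hypothetical stand-in for XYData."""

    _hash: str
    _path: str
    _value: Any


def wrapped_predict(
    model_name: str, m_hash: str, m_str: str, x: XY, predict: Callable[[XY], XY]
) -> XY:
    # Derived output hash, as in _get_data_key / _pre_predict.
    d_hash = hashlib.sha1(f"{m_str}.predict({x._hash})".encode("utf-8")).hexdigest()
    result = predict(x)
    # Only the value from the subclass survives; hash and path are replaced.
    return XY(_hash=d_hash, _path=f"{model_name}/{m_hash}", _value=result._value)


def subclass_predict(x: XY) -> XY:
    return XY(_hash="ignored", _path="ignored", _value=[v * 2 for v in x._value])


out = wrapped_predict("Doubler", "m1", "<Doubler()>(in, )", XY("in", "", [1, 2]), subclass_predict)
print(out._value)  # [2, 4]
print(out._path)   # Doubler/m1
```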

unwrap()

Return the base filter without any wrappers.

This method is useful when you need to access the original filter without any additional layers or modifications added by wrappers. The BaseFilter implementation returns self; wrapper filters are expected to override it.

Returns:

Name Type Description
BaseFilter BaseFilter

The unwrapped base filter.
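
On BaseFilter itself, unwrap simply returns self. The sketch below shows how a wrapper could override it to peel off its layers; Base, Logged, and the inner attribute are hypothetical names, not labchain classes.

```python
from typing import List


class Base:
    def predict(self, x: List[int]) -> List[int]:
        return x

    def unwrap(self) -> "Base":
        return self  # same behavior as BaseFilter.unwrap


class Logged(Base):
    """Hypothetical wrapper that delegates to an inner filter."""

    def __init__(self, inner: Base) -> None:
        self.inner = inner
        self.calls = 0

    def predict(self, x: List[int]) -> List[int]:
        self.calls += 1
        return self.inner.predict(x)

    def unwrap(self) -> Base:
        # Recurse so nested wrappers are all removed.
        return self.inner.unwrap()


core = Base()
stacked = Logged(Logged(core))
print(stacked.unwrap() is core)  # True
```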

verbose(value)

Set the verbosity of the filter.

Parameters:

Name Type Description Default
value bool

If True, enables verbose output; if False, disables it.

required

Returns:

Type Description
None

None
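
verbose only sets _verbose. That flag is read by _print_acction, which builds a banner (a line of 100 underscores, the action followed by "...", and a line of 100 asterisks) and prints it with rprint only when verbosity is on. The sketch below reproduces the same logic in a hypothetical Quiet class, with the built-in print in place of rprint.

```python
class Quiet:
    def __init__(self, verbose: bool = True) -> None:
        self._verbose = verbose

    def verbose(self, value: bool) -> None:
        self._verbose = value

    def banner(self, action: str) -> str:
        # Same layout as BaseFilter._print_acction.
        return "_" * 100 + f"\n{action}...\n" + "*" * 100

    def print_action(self, action: str) -> bool:
        if self._verbose:
            print(self.banner(action))
            return True
        return False


q = Quiet()
q.verbose(False)
print(q.print_action("Fitting"))  # False: no banner is printed
```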
