from framework3.utils.patch_type_guard import patch_inspect_for_notebooks
patch_inspect_for_notebooks()
✅ Patched inspect.getsource using dill.
First we import and prepare the data. We meed to use XYData class to track the data transformations in the pipeline. This let hash each data versión and provide caché captabilities.
from framework3.base import XYData
from sklearn import datasets
# Load iris dataset and convert it to XYData format
iris = datasets.load_iris()
X = XYData(
_hash="Iris X data",
_path="/datasets",
_value=iris.data, # type: ignore
)
y = XYData(
_hash="Iris y data",
_path="/datasets",
_value=iris.target, # type: ignore
)
Now we can define our custom filter class.
from typing import Optional
from sklearn.linear_model import LogisticRegression
from framework3.base import BaseFilter, XYData
from framework3 import Container
@Container.bind()
class CustomLogisticRegresion(BaseFilter):
def __init__(self, threshold: float = 0.5):
super().__init__()
self.threshold = threshold
# Non configuration attributes should be private
self._model = LogisticRegression()
def fit(self, x: XYData, y: Optional[XYData]) -> None:
X = x.value
if y is not None:
self._model.fit(X, y.value)
else:
raise ValueError("y must be provided for training")
def predict(self, x: XYData) -> XYData:
X = x.value
probabilities = self._model.predict_proba(X)[:, 1]
predictions = (probabilities > self.threshold).astype(int)
# We have to wrap the output class with a Mock XYData object
# The framework will update the attributs with the new hash data.
return XYData.mock(predictions)
Now we want to use this filter in our pipeline. We also will add a PCA filter and we will set several metrics: F1, Precision, and Recall.
from framework3 import F1, F3Pipeline, Precission, Recall
from framework3.plugins.filters import PCAPlugin
pipeline = F3Pipeline(
filters=[PCAPlugin(n_components=2), CustomLogisticRegresion()],
metrics=[F1(), Precission(), Recall()],
)
/home/manuel.couto.pintos/Documents/code/framework3/framework3/base/base_clases.py:56: InstrumentationWarning: instrumentor did not find the target function -- not typechecking __main__.CustomLogisticRegresion.__init__ cls.__init__ = typechecked(init_method) /home/manuel.couto.pintos/Documents/code/framework3/framework3/base/base_clases.py:64: InstrumentationWarning: instrumentor did not find the target function -- not typechecking __main__.CustomLogisticRegresion.fit setattr(cls, attr_name, typechecked(attr_value)) /home/manuel.couto.pintos/Documents/code/framework3/framework3/base/base_clases.py:64: InstrumentationWarning: instrumentor did not find the target function -- not typechecking __main__.CustomLogisticRegresion.predict setattr(cls, attr_name, typechecked(attr_value))
Note that we have some warnings due to type hinting. These are due to a limitation of typecheker in jupyter notebooks. It'll will be fixed them in the next versions.
pipeline.fit(X, y)
_y = pipeline.predict(X)
pipeline.evaluate(X, y, _y)
____________________________________________________________________________________________________
Fitting pipeline...
****************************************************************************************************
*PCAPlugin({'n_components': 2})
*CustomLogisticRegresion({'threshold': 0.5})
____________________________________________________________________________________________________
Predicting pipeline...
****************************************************************************************************
*PCAPlugin({'n_components': 2})
*CustomLogisticRegresion({'threshold': 0.5})
____________________________________________________________________________________________________
Evaluating pipeline......
****************************************************************************************************
{'F1': 0.5372488683746962, 'Precission': 0.4847443928066276, 'Recall': 0.6466666666666666}
We can apreciate that results are not the best posible with this basic example, but it gives you a starting point. To get better results, you should tune the hyperparameters of your models, preprocess your data, and add more features.