Sklearn Estimator
framework3.utils.skestimator
SkWrapper
Bases: BaseEstimator
A wrapper class for BaseFilter that implements scikit-learn's BaseEstimator interface.
This class allows BaseFilter objects to be used with scikit-learn's GridSearchCV and other scikit-learn-compatible tools.
Key Features
- Wraps any BaseFilter subclass to make it compatible with scikit-learn
- Implements fit, predict, and transform methods
- Supports getting and setting parameters
- Handles NotTrainableFilterError for filters that don't require training
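The wrapping pattern behind these features can be sketched without framework3 installed. The sketch below is a minimal, stdlib-only illustration, not the real SkWrapper source. `MiniWrapper`, `MeanThreshold` and `Doubler` are hypothetical stand-ins, and the local `NotTrainableFilterError` only mimics framework3's exception. It also assumes the wrapper silently skips training when the filter raises that error. The section says the error is "handled" but does not say how.

```python
from typing import Any, Dict, List, Type


class NotTrainableFilterError(Exception):
    """Local stand-in for framework3's exception of the same name."""


class MeanThreshold:
    """Hypothetical trainable filter: label 1 when a row sum exceeds the training mean."""

    def __init__(self, offset: float = 0.0) -> None:
        self.offset = offset
        self.mean_ = 0.0

    def fit(self, x: List[List[float]], y: Any) -> None:
        sums = [sum(row) for row in x]
        self.mean_ = sum(sums) / len(sums)

    def predict(self, x: List[List[float]]) -> List[int]:
        return [int(sum(row) > self.mean_ + self.offset) for row in x]


class Doubler:
    """Hypothetical stateless filter whose fit() refuses to train."""

    def fit(self, x: Any, y: Any) -> None:
        raise NotTrainableFilterError("Doubler has nothing to learn")

    def predict(self, x: List[List[float]]) -> List[List[float]]:
        return [[2 * v for v in row] for row in x]


class MiniWrapper:
    """Hypothetical sketch of the wrapping pattern, not the real SkWrapper."""

    def __init__(self, z_clazz: Type[Any], **kwargs: Any) -> None:
        self._z_clazz = z_clazz
        self.kwargs: Dict[str, Any] = kwargs
        self._model = z_clazz(**kwargs)

    def fit(self, x: Any, y: Any) -> "MiniWrapper":
        try:
            self._model.fit(x, y)
        except NotTrainableFilterError:
            pass  # assumed behaviour: stateless filters stay usable without training
        return self

    def predict(self, x: Any) -> Any:
        return self._model.predict(x)


x = [[1, 2], [2, 3], [3, 4], [4, 5]]
print(MiniWrapper(MeanThreshold).fit(x, None).predict(x))   # [0, 0, 1, 1]
print(MiniWrapper(Doubler).fit(x, None).predict([[1, 2]]))  # [[2, 4]]
```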
Usage
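```python
from framework3.plugins.filters.classification.svm import ClassifierSVMPlugin
from framework3.utils.skestimator import SkWrapper
import numpy as np
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Create a sample BaseFilter subclass
class SampleFilter(ClassifierSVMPlugin):
    pass

# Create an instance of SkWrapper
wrapper = SkWrapper(SampleFilter, C=1.0, kernel='rbf')

# Use the wrapper with sklearn's GridSearchCV
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5]])
y = np.array([0, 0, 1, 1])
param_grid = {'C': [0.1, 1, 10], 'kernel': ['rbf', 'linear']}
# There are only two samples per class, so use stratified 2-fold CV to keep both
# classes in every training fold. SkWrapper's only listed base is BaseEstimator,
# so an integer cv would fall back to unstratified KFold.
# No score() method is listed for SkWrapper, so name a scorer explicitly.
grid_search = GridSearchCV(wrapper, param_grid, cv=StratifiedKFold(n_splits=2), scoring='accuracy')
grid_search.fit(X, y)

# Make predictions
print(grid_search.predict(np.array([[2.5, 3.5]])))
```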
Attributes:
Name | Type | Description |
---|---|---|
_z_clazz | Type[BaseFilter] | The BaseFilter class to be wrapped. |
_model | BaseFilter | An instance of the wrapped BaseFilter class. |
kwargs | Dict[str, Any] | Keyword arguments passed to the wrapped BaseFilter class. |
Methods:
Name | Description |
---|---|
get_zclazz() -> str | Get the name of the wrapped BaseFilter class. |
fit(x: Any, y: Any, *args, **kwargs) -> SkWrapper | Fit the wrapped model to the given data. |
predict(x: Any) -> Any | Make predictions using the wrapped model. |
transform(x: Any) -> Any | Transform the input data using the wrapped model. |
get_params(deep: bool = True) -> Dict[str, Any] | Get the parameters of the estimator. |
set_params(**parameters) -> SkWrapper | Set the parameters of the estimator. |
Source code in framework3/utils/skestimator.py
kwargs = kwargs (instance attribute)
__init__(z_clazz, **kwargs)
Initialize the SkWrapper.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
z_clazz | Type[BaseFilter] | The BaseFilter class to be wrapped. | required |
**kwargs | Any | Keyword arguments to be passed to the wrapped BaseFilter class. | {} |
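A minimal sketch of what the constructor arguments are for, assuming (as the Attributes table suggests) that the wrapper keeps the class, keeps the keyword arguments, and builds one instance from them. `ScaleFilter` and `MiniWrapper` are hypothetical stand-ins, and building the model eagerly inside `__init__` is an assumption about timing. A useful consequence of forwarding kwargs verbatim is that an unknown keyword fails immediately with the filter's own `TypeError`.

```python
from typing import Any, Dict, Type


class ScaleFilter:
    """Hypothetical BaseFilter stand-in with two constructor arguments."""

    def __init__(self, factor: float = 1.0, shift: float = 0.0) -> None:
        self.factor = factor
        self.shift = shift


class MiniWrapper:
    """Hypothetical sketch of SkWrapper.__init__, not the real source."""

    def __init__(self, z_clazz: Type[Any], **kwargs: Any) -> None:
        self._z_clazz = z_clazz               # the class, kept so the model can be rebuilt
        self.kwargs: Dict[str, Any] = kwargs  # forwarded verbatim to the class
        self._model = z_clazz(**kwargs)       # one instance built from the same kwargs


w = MiniWrapper(ScaleFilter, factor=2.0)
print(type(w._model).__name__, w._model.factor, w._model.shift)  # ScaleFilter 2.0 0.0
```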
fit(x, y, *args, **kwargs)
Fit the wrapped model to the given data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Any | The input features. | required |
y | Any | The target values. | required |
*args | List[Any] | Additional positional arguments. | () |
**kwargs | Dict[str, Any] | Additional keyword arguments. | {} |
Returns:
Name | Type | Description |
---|---|---|
SkWrapper | SkWrapper | The fitted estimator. |
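Returning the estimator from `fit` follows the scikit-learn convention and is what allows calls to be chained. The stdlib-only sketch below illustrates this with hypothetical `MajorityFilter` and `MiniWrapper` classes. Ignoring the extra `*args`/`**kwargs` is a simplification of this sketch only. The section does not say what the real method does with them.

```python
from typing import Any, List


class MajorityFilter:
    """Hypothetical filter that always predicts the most common training label."""

    def fit(self, x: Any, y: List[int]) -> None:
        self.label_ = max(set(y), key=y.count)

    def predict(self, x: List[Any]) -> List[int]:
        return [self.label_ for _ in x]


class MiniWrapper:
    """Hypothetical sketch; only fit/predict are shown."""

    def __init__(self, z_clazz: Any, **kwargs: Any) -> None:
        self._model = z_clazz(**kwargs)

    def fit(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> "MiniWrapper":
        self._model.fit(x, y)  # extra *args/**kwargs are ignored in this sketch
        return self            # returning self is what makes chaining possible

    def predict(self, x: Any) -> Any:
        return self._model.predict(x)


preds = MiniWrapper(MajorityFilter).fit([[0], [1], [2]], [1, 1, 0]).predict([[5], [6]])
print(preds)  # [1, 1]
```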
get_params(deep=True)
Get the parameters of the estimator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deep | bool | If True, return the parameters for this estimator and for any contained subobjects that are estimators. | True |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Parameter names mapped to their values. |
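The reason `get_params` matters for GridSearchCV is that `sklearn.base.clone` rebuilds an unfitted copy by calling the estimator's class with the result of `get_params(deep=False)`. The sketch below reproduces that core step with a hypothetical `toy_clone` so it runs with the standard library alone. It assumes that `get_params` reports `z_clazz` alongside the forwarded kwargs, because the constructor needs both. The section above does not state which keys the real method returns. `ScaleFilter` and `MiniWrapper` are hypothetical.

```python
from typing import Any, Dict, Type


class ScaleFilter:
    """Hypothetical BaseFilter stand-in."""

    def __init__(self, factor: float = 1.0) -> None:
        self.factor = factor


class MiniWrapper:
    """Hypothetical sketch; get_params returns z_clazz plus the forwarded kwargs."""

    def __init__(self, z_clazz: Type[Any], **kwargs: Any) -> None:
        self._z_clazz = z_clazz
        self.kwargs = kwargs
        self._model = z_clazz(**kwargs)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        # `deep` is accepted for API compatibility; this toy has no nested estimators.
        return {"z_clazz": self._z_clazz, **self.kwargs}


def toy_clone(est: Any) -> Any:
    """Core idea of sklearn.base.clone: rebuild an unfitted copy from get_params()."""
    return type(est)(**est.get_params(deep=False))


original = MiniWrapper(ScaleFilter, factor=3.0)
rebuilt = toy_clone(original)
print(rebuilt is original, rebuilt.get_params() == original.get_params())  # False True
```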
get_zclazz()
Get the name of the wrapped BaseFilter class.
Returns:
Name | Type | Description |
---|---|---|
str | str | The name of the wrapped BaseFilter class. |
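A minimal sketch of `get_zclazz`, assuming it returns the class's bare `__name__` rather than a module-qualified path. The section only says "the name". `SampleFilter` and `MiniWrapper` are hypothetical stand-ins.

```python
from typing import Any, Type


class SampleFilter:
    """Hypothetical empty BaseFilter stand-in."""


class MiniWrapper:
    """Hypothetical sketch; only the parts needed for get_zclazz are shown."""

    def __init__(self, z_clazz: Type[Any], **kwargs: Any) -> None:
        self._z_clazz = z_clazz
        self.kwargs = kwargs

    def get_zclazz(self) -> str:
        # Assumed implementation: the bare class name, not module.ClassName.
        return self._z_clazz.__name__


print(MiniWrapper(SampleFilter, C=1.0).get_zclazz())  # SampleFilter
```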
predict(x)
Make predictions using the wrapped model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Any | The input features. | required |
Returns:
Name | Type | Description |
---|---|---|
Any | Any | The predicted values. |
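`predict` is also what a scikit-learn accuracy scorer calls during a grid search: it predicts on the held-out fold and compares the result with the true labels. The stdlib-only sketch below shows plain delegation plus a hypothetical `accuracy` helper that mirrors that scoring step. `ThresholdFilter` and `MiniWrapper` are hypothetical, and pure delegation without post-processing is an assumption about the real method.

```python
from typing import Any, List


class ThresholdFilter:
    """Hypothetical filter: label 1 when the first feature exceeds `cut`."""

    def __init__(self, cut: float) -> None:
        self.cut = cut

    def predict(self, x: List[List[float]]) -> List[int]:
        return [int(row[0] > self.cut) for row in x]


class MiniWrapper:
    """Hypothetical sketch; only predict is shown."""

    def __init__(self, z_clazz: Any, **kwargs: Any) -> None:
        self._model = z_clazz(**kwargs)

    def predict(self, x: Any) -> Any:
        return self._model.predict(x)  # pure delegation in this sketch


def accuracy(est: Any, x: Any, y: List[int]) -> float:
    """Fraction of predictions equal to the true labels."""
    preds = est.predict(x)
    return sum(p == t for p, t in zip(preds, y)) / len(y)


w = MiniWrapper(ThresholdFilter, cut=2.5)
print(w.predict([[1, 2], [3, 4]]))                                   # [0, 1]
print(accuracy(w, [[1, 2], [2, 3], [3, 4], [4, 5]], [0, 0, 1, 1]))  # 1.0
```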
set_params(**parameters)
Set the parameters of the estimator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**parameters | dict | Estimator parameters. | {} |
Returns:
Type | Description |
---|---|
SkWrapper | Estimator instance. |
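During a grid search, each candidate from `param_grid` reaches the estimator through `set_params` before `fit` is called. The stdlib-only loop below sketches that cycle. It assumes `set_params` merges the new values into `kwargs` and rebuilds the wrapped model, which the section does not confirm. `ThresholdFilter` and `MiniWrapper` are hypothetical. GridSearchCV itself clones the estimator for every candidate and fold. This loop reuses one instance only for brevity.

```python
from typing import Any, Dict, List


class ThresholdFilter:
    """Hypothetical filter: label 1 when the first feature exceeds `cut`."""

    def __init__(self, cut: float = 0.0) -> None:
        self.cut = cut

    def fit(self, x: Any, y: Any) -> None:
        pass  # nothing is learned; `cut` is the hyperparameter being searched

    def predict(self, x: List[List[float]]) -> List[int]:
        return [int(row[0] > self.cut) for row in x]


class MiniWrapper:
    """Hypothetical sketch of the set_params/fit/predict cycle."""

    def __init__(self, z_clazz: Any, **kwargs: Any) -> None:
        self._z_clazz = z_clazz
        self.kwargs: Dict[str, Any] = kwargs
        self._model = z_clazz(**kwargs)

    def set_params(self, **parameters: Any) -> "MiniWrapper":
        # Assumed behaviour: merge into kwargs and rebuild the wrapped model.
        self.kwargs.update(parameters)
        self._model = self._z_clazz(**self.kwargs)
        return self

    def fit(self, x: Any, y: Any) -> "MiniWrapper":
        self._model.fit(x, y)
        return self

    def predict(self, x: Any) -> Any:
        return self._model.predict(x)


x = [[1.0], [2.0], [3.0], [4.0]]
y = [0, 0, 1, 1]
wrapper = MiniWrapper(ThresholdFilter)
scores: Dict[float, float] = {}
for cut in [0.5, 2.5, 3.5]:
    preds = wrapper.set_params(cut=cut).fit(x, y).predict(x)
    scores[cut] = sum(p == t for p, t in zip(preds, y)) / len(y)
best = max(scores, key=scores.get)
print(best, scores)  # 2.5 {0.5: 0.5, 2.5: 1.0, 3.5: 0.75}
```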
transform(x)
Transform the input data using the wrapped model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Any | The input features. | required |
Returns:
Name | Type | Description |
---|---|---|
Any | Any | The transformed data. |
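`transform` is the entry point scikit-learn tools use for preprocessing steps. The stdlib-only sketch below wraps a hypothetical column-centring filter. It assumes the wrapped filter exposes its output through `predict`, so `transform` forwards to that method. Whether the real SkWrapper does the same is not stated in this section. `Standardize` and `MiniWrapper` are hypothetical.

```python
from typing import Any, List


class Standardize:
    """Hypothetical filter: subtracts each column's training mean."""

    def fit(self, x: List[List[float]], y: Any) -> None:
        n = len(x)
        self.means_ = [sum(col) / n for col in zip(*x)]

    def predict(self, x: List[List[float]]) -> List[List[float]]:
        return [[v - m for v, m in zip(row, self.means_)] for row in x]


class MiniWrapper:
    """Hypothetical sketch; only fit/transform are shown."""

    def __init__(self, z_clazz: Any, **kwargs: Any) -> None:
        self._model = z_clazz(**kwargs)

    def fit(self, x: Any, y: Any) -> "MiniWrapper":
        self._model.fit(x, y)
        return self

    def transform(self, x: Any) -> Any:
        # Assumption: the filter's single output method is predict(),
        # so transform() forwards to it.
        return self._model.predict(x)


x = [[1, 10], [3, 30]]
result = MiniWrapper(Standardize).fit(x, None).transform(x)
print(result)  # [[-1.0, -10.0], [1.0, 10.0]]
```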