GridSearchCVFilter
framework3.plugins.filters.grid_search
GridSearchCVPlugin
Bases: BaseFilter
A plugin for performing hyperparameter tuning on BaseFilter objects using scikit-learn's GridSearchCV.
This plugin automates the process of finding optimal hyperparameters for a given BaseFilter by evaluating different combinations of parameters through cross-validation.
Key Features
- Integrates scikit-learn's GridSearchCV with framework3's BaseFilter
- Supports hyperparameter tuning for any BaseFilter
- Allows specification of parameter grid, scoring metric, and cross-validation strategy
- Provides methods for fitting the model and making predictions with the best found parameters
Usage
The GridSearchCVPlugin can be used to perform hyperparameter tuning on a BaseFilter:
```python
from framework3.plugins.filters.clasification.svm import ClassifierSVMPlugin
from framework3.plugins.filters.grid_search.cv_grid_search import GridSearchCVPlugin
from framework3.base.base_types import XYData
import numpy as np

# Create sample data
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5]])
y = np.array([0, 0, 1, 1])
X_data = XYData(_hash='X_data', _path='/tmp', _value=X)
y_data = XYData(_hash='y_data', _path='/tmp', _value=y)

# Define the parameter grid
param_grid = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf'],
    'gamma': ['scale', 'auto']
}

# Create the GridSearchCVPlugin
grid_search = GridSearchCVPlugin(
    filterx=ClassifierSVMPlugin,
    param_grid=param_grid,
    scoring='accuracy',
    cv=3
)

# Fit the grid search
grid_search.fit(X_data, y_data)

# Make predictions
X_test = XYData(_hash='X_test', _path='/tmp', _value=np.array([[2.5, 3.5]]))
predictions = grid_search.predict(X_test)
print(predictions.value)

# Access the best parameters
print(grid_search._clf.best_params_)
```
Attributes:
Name | Type | Description |
---|---|---|
_clf | GridSearchCV | The GridSearchCV object used for hyperparameter tuning. |
Methods:
Name | Description |
---|---|
fit(x: XYData, y: XYData) | Fit the GridSearchCV object to the given data. |
predict(x: XYData) -> XYData | Make predictions using the best estimator found by GridSearchCV. |
Note
This plugin depends on scikit-learn's GridSearchCV. Make sure a version of scikit-learn compatible with your environment is installed.
Source code in framework3/plugins/filters/grid_search/cv_grid_search.py
__init__(filterx, param_grid, scoring, cv=2)
Initialize a new GridSearchCVPlugin instance.
This constructor sets up the GridSearchCVPlugin with the specified BaseFilter, parameter grid, scoring metric, and cross-validation strategy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filterx | Type[BaseFilter] | The BaseFilter class to be tuned. | required |
param_grid | Dict[str, Any] | Dictionary with parameter names as keys and lists of parameter settings to try as values. | required |
scoring | str | Metric used to score the cross-validated model on each held-out fold, e.g. 'accuracy'. | required |
cv | int | Cross-validation splitting strategy; as an integer, the number of folds. | 2 |
Note
The GridSearchCV object is initialized with a Pipeline containing the specified BaseFilter wrapped in an SkFilterWrapper to ensure compatibility with scikit-learn's API.
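As a rough illustration of what the param_grid argument implies, the sketch below expands a grid into the individual candidate settings that an exhaustive search tries. It uses only the standard library and the grid from the Usage section. The helper name `expand_grid` is hypothetical and is not part of framework3 or scikit-learn.

```python
from itertools import product
from typing import Any, Dict, Iterator, List


def expand_grid(param_grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one dict per combination of the listed parameter values.

    Hypothetical helper, for illustration only.
    """
    names = list(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


param_grid = {
    "C": [0.1, 1, 10],
    "kernel": ["linear", "rbf"],
    "gamma": ["scale", "auto"],
}

candidates = list(expand_grid(param_grid))
print(len(candidates))  # 3 * 2 * 2 = 12 candidate settings
print(candidates[0])    # first combination in grid order
```

Each candidate is evaluated once per cross-validation fold, so the cost of a search grows with the product of the list lengths times the number of folds.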
fit(x, y)
Fit the GridSearchCV object to the given data.
This method performs the grid search over the specified parameter grid, fitting the model with different parameter combinations and selecting the best one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | XYData | The input features. | required |
y | XYData | The target values. | required |
Note
This method modifies the internal state of the GridSearchCV object, storing the best parameters and the corresponding fitted model.
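The search that fit performs can be pictured with a small pure-Python sketch: score every candidate on each cross-validation fold, keep the candidate with the best mean score, then refit it on all the data. This is not scikit-learn's implementation. `MidpointClassifier`, `k_fold`, `accuracy`, and `grid_search` are hypothetical names introduced here, and the toy model has a single hyperparameter, `offset`.

```python
from statistics import mean
from typing import Dict, List, Sequence, Tuple


class MidpointClassifier:
    """Toy 1-D classifier (hypothetical): threshold = midpoint of class means + offset."""

    def __init__(self, offset: float = 0.0) -> None:
        self.offset = offset
        self.threshold = 0.0

    def fit(self, x: Sequence[float], y: Sequence[int]) -> "MidpointClassifier":
        mean0 = mean(xi for xi, yi in zip(x, y) if yi == 0)
        mean1 = mean(xi for xi, yi in zip(x, y) if yi == 1)
        self.threshold = (mean0 + mean1) / 2 + self.offset
        return self

    def predict(self, x: Sequence[float]) -> List[int]:
        return [1 if xi > self.threshold else 0 for xi in x]


def k_fold(n_samples: int, n_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Contiguous, unshuffled folds; the first n_samples % n_splits folds get one extra sample."""
    folds, start = [], 0
    for i in range(n_splits):
        size = n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        test = list(range(start, start + size))
        train = [j for j in range(n_samples) if j < start or j >= start + size]
        folds.append((train, test))
        start += size
    return folds


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return mean(1.0 if t == p else 0.0 for t, p in zip(y_true, y_pred))


def grid_search(x, y, offsets, n_splits):
    best_params: Dict[str, float] = {}
    best_score = float("-inf")
    for offset in offsets:
        fold_scores = []
        for train, test in k_fold(len(x), n_splits):
            model = MidpointClassifier(offset).fit([x[i] for i in train], [y[i] for i in train])
            fold_scores.append(accuracy([y[i] for i in test], model.predict([x[i] for i in test])))
        score = mean(fold_scores)
        if score > best_score:  # the first candidate wins ties
            best_params, best_score = {"offset": offset}, score
    # Refit the winning candidate on the full data set.
    best_model = MidpointClassifier(**best_params).fit(x, y)
    return best_params, best_score, best_model


x = [1.0, 4.0, 2.0, 3.0, 1.0, 4.0]
y = [0, 1, 0, 1, 0, 1]
best_params, best_score, best_model = grid_search(x, y, offsets=[-1.0, 0.0, 1.0], n_splits=3)
print(best_params, best_score)
print(best_model.predict([2.0, 3.0]))
```

The refit at the end is what makes a later predict call meaningful: the model that scored best during cross-validation was only ever trained on subsets of the data.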
predict(x)
Make predictions using the best estimator found by GridSearchCV.
This method uses the best model found during the grid search to make predictions on new data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | XYData | The input features. | required |
Returns:
Type | Description |
---|---|
XYData | The predicted values wrapped in an XYData object. |
Note
The predictions are wrapped in an XYData object for consistency with the framework.
cv_grid_search
__all__ = ['GridSearchCVPlugin']
module-attribute
SkFilterWrapper
Bases: BaseEstimator
A wrapper class for BaseFilter that implements scikit-learn's BaseEstimator interface.
This class enables BaseFilter objects to be used with scikit-learn's GridSearchCV, bridging the gap between framework3's filters and scikit-learn's estimator interface.
Key Features
- Wraps BaseFilter objects to comply with scikit-learn's BaseEstimator interface
- Allows use of framework3 filters in scikit-learn's GridSearchCV
- Provides methods for fitting, predicting, and parameter management
Usage
The SkFilterWrapper can be used to wrap a BaseFilter for use with GridSearchCV:
```python
from framework3.plugins.filters.clasification.svm import ClassifierSVMPlugin
from framework3.plugins.filters.grid_search.cv_grid_search import SkFilterWrapper
import numpy as np

# Set the class to be wrapped
SkFilterWrapper.z_clazz = ClassifierSVMPlugin

# Create an instance of SkFilterWrapper
wrapper = SkFilterWrapper(C=1.0, kernel='rbf')

# Fit and predict through the scikit-learn-style interface
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5]])
y = np.array([0, 0, 1, 1])
wrapper.fit(X, y)
print(wrapper.predict([[2.5, 3.5]]))
```
Attributes:
Name | Type | Description |
---|---|---|
z_clazz | Type[BaseFilter] | The BaseFilter class to be wrapped. |
_model | BaseFilter | The instance of the wrapped BaseFilter. |
kwargs | Dict[str, Any] | The keyword arguments used to initialize the wrapped BaseFilter. |
Methods:
Name | Description |
---|---|
fit | Fit the wrapped model to the given data. |
predict | Make predictions using the wrapped model. |
get_params | Get the parameters of the estimator. |
set_params | Set the parameters of the estimator. |
Note
This wrapper is specifically designed to work with framework3's BaseFilter and scikit-learn's GridSearchCV. Ensure that the wrapped BaseFilter is compatible with the data and operations you intend to perform.
kwargs = kwargs
instance-attribute
z_clazz
instance-attribute
__init__(clazz, **kwargs)
Initialize a new SkFilterWrapper instance.
This constructor creates an instance of the specified BaseFilter class with the given parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clazz | Type[BaseFilter] | The BaseFilter class to be instantiated. | required |
**kwargs | Dict[str, Any] | Keyword arguments to be passed to the BaseFilter constructor. | {} |
Note
The initialized BaseFilter instance is stored in self._model, and the kwargs are stored for later use in get_params and set_params.
fit(x, y, *args, **kwargs)
Fit the wrapped model to the given data.
This method wraps the input data in XYData objects and calls the fit method of the wrapped BaseFilter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | TxyData | The input features. | required |
y | TxyData | The target values. | required |
*args | List[Any] | Additional positional arguments (not used). | () |
**kwargs | Dict[str, Any] | Additional keyword arguments (not used). | {} |
Returns:
Name | Type | Description |
---|---|---|
self | SkFilterWrapper | The fitted estimator. |
Note
This method modifies the internal state of the wrapped model.
get_params(deep=True)
Get the parameters of the estimator.
This method returns the kwargs used to initialize the wrapped BaseFilter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deep | bool | If True, will return the parameters for this estimator and contained subobjects that are estimators. Not used in this implementation. | True |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Parameter names mapped to their values. |
Note
The 'deep' parameter is included for compatibility with scikit-learn, but it doesn't affect the output in this implementation.
predict(x)
Make predictions using the wrapped model.
This method wraps the input data in an XYData object, calls the predict method of the wrapped BaseFilter, and returns the raw value from the resulting XYData.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | TxyData | The input features. | required |
Returns:
Type | Description |
---|---|
TxyData | The predicted values. |
Note
The return value is the raw prediction, not wrapped in an XYData object.
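The wrap-then-unwrap pattern described for fit and predict can be sketched without framework3 at all. In the example below, `Box` is a hypothetical stand-in for XYData, `MajorityFilter` is a hypothetical stand-in for a BaseFilter, and `Adapter` plays the role of the wrapper. None of these names exist in framework3; the real classes carry more state (such as hashes and paths).

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class Box:
    """Hypothetical stand-in for XYData: a thin container around a raw value."""
    value: Any


class MajorityFilter:
    """Hypothetical stand-in for a BaseFilter: consumes and produces Box objects."""

    def __init__(self) -> None:
        self.label: Optional[int] = None

    def fit(self, x: Box, y: Box) -> None:
        labels = list(y.value)
        self.label = max(set(labels), key=labels.count)  # most frequent label

    def predict(self, x: Box) -> Box:
        return Box([self.label for _ in x.value])


class Adapter:
    """Accepts and returns raw values, boxing them only at the boundary."""

    def __init__(self, model: MajorityFilter) -> None:
        self._model = model

    def fit(self, x: Any, y: Any) -> "Adapter":
        self._model.fit(Box(x), Box(y))
        return self  # returning self mirrors the scikit-learn convention

    def predict(self, x: Any) -> List[int]:
        return self._model.predict(Box(x)).value  # unwrap before returning


adapter = Adapter(MajorityFilter()).fit([[1, 2], [2, 3], [3, 4]], [0, 1, 1])
print(adapter.predict([[2.5, 3.5], [0, 0]]))
```

Returning the raw value matters because scikit-learn's scorers compare predictions directly against the target array and know nothing about the framework's container type.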
set_params(**parameters)
Set the parameters of the estimator.
This method creates a new instance of the wrapped BaseFilter with the specified parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**parameters | Dict[str, Any] | Estimator parameters. | {} |
Returns:
Name | Type | Description |
---|---|---|
self | SkFilterWrapper | Estimator instance. |
Note
This method replaces the existing wrapped model with a new instance.
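A minimal sketch of the get_params / set_params behaviour described above, using only the standard library. `ToyFilter` and `ParamWrapper` are hypothetical names, and the sketch assumes that set_params replaces the stored keyword arguments rather than merging them; the actual SkFilterWrapper may differ on that point.

```python
from typing import Any, Dict, Type


class ToyFilter:
    """Hypothetical filter with two hyperparameters."""

    def __init__(self, scale: float = 1.0, shift: float = 0.0) -> None:
        self.scale = scale
        self.shift = shift


class ParamWrapper:
    """Hypothetical wrapper that rebuilds its model whenever parameters change."""

    def __init__(self, clazz: Type[Any], **kwargs: Any) -> None:
        self.clazz = clazz
        self.kwargs = dict(kwargs)
        self._model = clazz(**self.kwargs)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        # `deep` is accepted for interface compatibility but ignored here.
        return dict(self.kwargs)

    def set_params(self, **parameters: Any) -> "ParamWrapper":
        # Assumption: the new parameters replace the old ones entirely.
        self.kwargs = dict(parameters)
        self._model = self.clazz(**self.kwargs)  # fresh, unfitted instance
        return self


wrapper = ParamWrapper(ToyFilter, scale=2.0)
old_model = wrapper._model
wrapper.set_params(scale=3.0, shift=1.0)
print(wrapper.get_params())
print(wrapper._model is old_model)
```

Rebuilding the model on every set_params call means no fitted state leaks from one candidate setting to the next during a search.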