KFoldSplitter
framework3.plugins.splitter.cross_validation_splitter
Bases: BaseSplitter
A K-Fold cross-validation splitter for evaluating machine learning models.
This class implements K-Fold cross-validation, which splits the dataset into K folds of (nearly) equal size. The model is trained on K-1 folds and validated on the remaining fold. This process is repeated K times, so that each fold serves as the validation set exactly once.
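To make the splitting scheme concrete, here is a minimal NumPy sketch of how K-fold index pairs can be generated. It is an illustration under assumptions, not framework3's code: `kfold_indices` is a hypothetical helper whose `n_splits`, `shuffle`, and `random_state` arguments mirror the constructor parameters, and its shuffled order need not match any other library's K-fold implementation for the same seed.

```python
from typing import Iterator, Tuple

import numpy as np


def kfold_indices(
    n_samples: int, n_splits: int = 5, shuffle: bool = True, random_state: int = 42
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (train_idx, val_idx) pairs; hypothetical helper, not framework3 code."""
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    if n_splits > n_samples:
        raise ValueError("n_splits cannot exceed the number of samples")
    indices = np.arange(n_samples)
    if shuffle:
        # A seeded generator makes the same random_state produce the same folds
        np.random.default_rng(random_state).shuffle(indices)
    # Every fold gets n_samples // n_splits samples; the first
    # n_samples % n_splits folds receive one extra sample
    fold_sizes = np.full(n_splits, n_samples // n_splits, dtype=int)
    fold_sizes[: n_samples % n_splits] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = indices[start:stop]
        # Training indices are everything outside the current validation fold
        train_idx = np.concatenate([indices[:start], indices[stop:]])
        yield train_idx, val_idx
        start = stop


for train_idx, val_idx in kfold_indices(10, n_splits=3):
    print(len(train_idx), len(val_idx))  # prints 6 4, then 7 3 twice
```

Every sample appears in exactly one validation fold, and fold sizes differ by at most one when the sample count is not divisible by K.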
Key Features
- Configurable number of splits
- Option to shuffle data before splitting
- Supports custom pipelines for model training and evaluation
- Provides mean loss across all folds
Usage
```python
import numpy as np

from framework3.base import XYData
from framework3.plugins.pipelines.sequential import F3Pipeline
from framework3.plugins.splitter import KFoldSplitter

# Create a dummy pipeline
pipeline = F3Pipeline(filters=[...], metrics=[...])

# Create the KFoldSplitter
splitter = KFoldSplitter(n_splits=5, shuffle=True, random_state=42, pipeline=pipeline)

# Prepare some dummy data
X = XYData(value=np.random.rand(100, 10))
y = XYData(value=np.random.randint(0, 2, 100))

# Fit and evaluate the model using cross-validation
mean_loss = splitter.fit(X, y)
print(f"Mean loss across folds: {mean_loss}")

# Make predictions on new data
X_new = XYData(value=np.random.rand(20, 10))
predictions = splitter.predict(X_new)
```
Attributes:

Name | Type | Description |
---|---|---|
`n_splits` | `int` | Number of folds. Must be at least 2. |
`shuffle` | `bool` | Whether to shuffle the data before splitting. |
`random_state` | `int` | Controls the shuffling applied to the data before applying the split. |
`pipeline` | `BaseFilter \| None` | The pipeline to be used for training and evaluation. |

Methods:

Name | Description |
---|---|
`split(pipeline: BaseFilter)` | Set the pipeline for the splitter. |
`fit(x: XYData, y: XYData \| None) -> Optional[float]` | Perform K-Fold cross-validation. |
`predict(x: XYData) -> XYData` | Make predictions using the fitted pipeline. |
`evaluate(x_data: XYData, y_true: XYData \| None, y_pred: XYData) -> Dict[str, Any]` | Evaluate the pipeline on the provided data. |
`start(x: XYData, y: Optional[XYData], X_: Optional[XYData]) -> Optional[XYData]` | Start the cross-validation process and optionally make predictions. |

Source code in framework3/plugins/splitter/cross_validation_splitter.py
n_splits = n_splits (instance attribute)
pipeline = pipeline (instance attribute)
random_state = random_state (instance attribute)
shuffle = shuffle (instance attribute)
__init__(n_splits=5, shuffle=True, random_state=42, pipeline=None)
Initialize the KFoldSplitter.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`n_splits` | `int` | Number of folds. Must be at least 2. | `5` |
`shuffle` | `bool` | Whether to shuffle the data before splitting. | `True` |
`random_state` | `int` | Controls the shuffling applied to the data before applying the split. | `42` |
`pipeline` | `BaseFilter \| None` | The pipeline to be used for training and evaluation. | `None` |

evaluate(x_data, y_true, y_pred)
Evaluate the pipeline using the provided data.
This method uses the pipeline's evaluate method to assess its performance on the given data.
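Because `evaluate` delegates to the pipeline's own evaluation, its result is a mapping from metric names to scores. The sketch below illustrates that contract with plain NumPy arrays instead of `XYData`; `EvaluatingSplitter`, its `fitted` flag, and the `accuracy` metric are hypothetical stand-ins, not framework3 APIs.

```python
from typing import Any, Callable, Dict, Optional

import numpy as np

Metric = Callable[[np.ndarray, np.ndarray], float]


def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # Fraction of predictions that match the true labels
    return float(np.mean(y_true == y_pred))


class EvaluatingSplitter:
    """Hypothetical sketch: evaluate() returns one score per named metric."""

    def __init__(self, metrics: Dict[str, Metric]) -> None:
        self.metrics = metrics
        self.fitted = False  # a successful fit() would set this to True

    def evaluate(
        self, x_data: np.ndarray, y_true: Optional[np.ndarray], y_pred: np.ndarray
    ) -> Dict[str, Any]:
        if not self.fitted:
            raise ValueError("The pipeline has not been fitted.")
        if y_true is None:
            # Sketch-specific choice: these label-based metrics need y_true
            raise ValueError("These metrics require y_true.")
        # x_data is kept for signature parity; these metrics do not use it
        return {name: fn(y_true, y_pred) for name, fn in self.metrics.items()}
```

With labels `[0, 1, 1, 0]` and predictions `[0, 1, 0, 0]`, a fitted instance returns `{"accuracy": 0.75}`.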
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`x_data` | `XYData` | The input features. | required |
`y_true` | `XYData \| None` | The true target values. | required |
`y_pred` | `XYData` | The predicted target values. | required |

Returns:

Type | Description |
---|---|
`Dict[str, Any]` | A dictionary containing the evaluation metrics. |

Raises:

Type | Description |
---|---|
`ValueError` | If the pipeline has not been fitted. |

fit(x, y)
Perform K-Fold cross-validation on the given data.
This method splits the data into K folds, trains the pipeline on K-1 folds, and evaluates it on the remaining fold. This process is repeated K times.
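The loop described above can be sketched as follows. This is a simplified stand-in, not the framework3 implementation: NumPy arrays replace `XYData`, a hypothetical `MeanRegressor` replaces the pipeline, and mean squared error is assumed as the per-fold loss. A fresh model is trained on each set of K-1 folds, scored on the held-out fold, and the per-fold losses are averaged.

```python
from typing import List, Optional

import numpy as np


class MeanRegressor:
    """Hypothetical stand-in for a pipeline: always predicts the training mean."""

    def fit(self, x: np.ndarray, y: np.ndarray) -> "MeanRegressor":
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, x: np.ndarray) -> np.ndarray:
        return np.full(len(x), self.mean_)


def kfold_mean_loss(
    x: np.ndarray,
    y: Optional[np.ndarray],
    n_splits: int = 5,
    shuffle: bool = True,
    random_state: int = 42,
) -> Optional[float]:
    """Train on K-1 folds, score on the held-out fold, repeat K times, average."""
    if y is None:
        raise ValueError("y is required for cross-validation")
    indices = np.arange(len(x))
    if shuffle:
        indices = np.random.default_rng(random_state).permutation(indices)
    losses: List[float] = []
    # np.array_split returns n_splits chunks whose sizes differ by at most one
    for val_idx in np.array_split(indices, n_splits):
        train_idx = np.setdiff1d(indices, val_idx)
        model = MeanRegressor().fit(x[train_idx], y[train_idx])  # fresh model per fold
        # Assumed loss: mean squared error on the held-out fold
        residuals = model.predict(x[val_idx]) - y[val_idx]
        losses.append(float(np.mean(residuals ** 2)))
    return float(np.mean(losses)) if losses else None
```

With a constant target the mean predictor is exact, so `kfold_mean_loss(x, np.full(10, 3.0))` returns `0.0` for any 10-row `x`.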
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`x` | `XYData` | The input features. | required |
`y` | `XYData \| None` | The target values. | required |

Returns:

Type | Description |
---|---|
`Optional[float]` | The mean loss across all folds, or `None` if no losses were calculated. |

Raises:

Type | Description |
---|---|
`ValueError` | If `y` is `None` or if the pipeline is not set. |

predict(x)
Make predictions using the fitted pipeline.
This method uses the pipeline that was fitted during cross-validation to make predictions on new data.
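The fitted-state check that makes `predict` raise `ValueError` can be illustrated with a small sketch. `FittedGuardSplitter` is hypothetical, and an ordinary least-squares model fitted with `np.linalg.lstsq` stands in for the pipeline; the sketch does not model which fitted pipeline the real splitter keeps for prediction.

```python
from typing import Optional

import numpy as np


class FittedGuardSplitter:
    """Hypothetical sketch: predict() refuses to run before fit()."""

    def __init__(self) -> None:
        self._coef: Optional[np.ndarray] = None  # populated by fit()

    def fit(self, x: np.ndarray, y: np.ndarray) -> None:
        # Least-squares linear model stands in for the real pipeline
        self._coef, *_ = np.linalg.lstsq(x, y, rcond=None)

    def predict(self, x: np.ndarray) -> np.ndarray:
        if self._coef is None:
            raise ValueError("The pipeline has not been fitted.")
        return x @ self._coef
```

Calling `predict` on a fresh instance raises `ValueError`; after `fit`, it returns one prediction per input row.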
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`x` | `XYData` | The input features for prediction. | required |

Returns:

Type | Description |
---|---|
`XYData` | The predictions made by the pipeline. |

Raises:

Type | Description |
---|---|
`ValueError` | If the pipeline has not been fitted. |

split(pipeline)
Set the pipeline for the splitter and disable its verbosity.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`pipeline` | `BaseFilter` | The pipeline to be used for training and evaluation. | required |

start(x, y, X_)
Start the cross-validation process and optionally make predictions.
This method performs cross-validation by fitting the model and then makes predictions if X_ is provided.
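The control flow of `start` reduces to "fit, then predict only when `X_` is given". The sketch below expresses that with plain callables; the `start` function and the `FitFn` and `PredictFn` aliases here are hypothetical illustrations, not the framework's signatures.

```python
from typing import Callable, Optional

import numpy as np

FitFn = Callable[[np.ndarray, Optional[np.ndarray]], Optional[float]]
PredictFn = Callable[[np.ndarray], np.ndarray]


def start(
    fit: FitFn,
    predict: PredictFn,
    x: np.ndarray,
    y: Optional[np.ndarray],
    X_: Optional[np.ndarray],
) -> Optional[np.ndarray]:
    """Hypothetical sketch: run cross-validation, then predict on X_ if given."""
    fit(x, y)  # the mean loss returned by fit is not part of start()'s result
    if X_ is None:
        return None
    return predict(X_)
```

Passing `X_=None` runs only the cross-validation step and returns `None`; passing an array also triggers a prediction on it.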
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`x` | `XYData` | The input features for training. | required |
`y` | `Optional[XYData]` | The target values for training. | required |
`X_` | `Optional[XYData]` | The input features for prediction, if different from `x`. | required |

Returns:

Type | Description |
---|---|
`Optional[XYData]` | Prediction results if `X_` is provided, otherwise `None`. |

Raises:

Type | Description |
---|---|
`Exception` | If an error occurs during the process. |