StratifiedKFoldSplitter
labchain.plugins.splitter.stratified_cross_validation_splitter
Bases: BaseSplitter
A Stratified K-Fold cross-validation splitter for evaluating classification models.
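This class implements Stratified K-Fold cross-validation, which splits the dataset into K folds so that every fold keeps roughly the same class proportions as the full dataset. It is particularly useful for imbalanced datasets, where a plain K-Fold split can leave a fold with few or no samples of a minority class.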
Key Features
- Preserves label distribution across folds
- Configurable number of splits
- Option to shuffle data before splitting
- Supports custom pipelines for model training and evaluation
- Provides mean loss across all folds
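To make the stratification idea concrete, here is a minimal pure-Python sketch. It is not the class's actual implementation: the helper `stratified_folds` and its round-robin dealing strategy are assumptions introduced for illustration, and only the parameter names `n_splits`, `shuffle`, and `random_state` come from this page.

```python
import random
from collections import Counter, defaultdict


def stratified_folds(labels, n_splits=5, shuffle=True, random_state=42):
    """Hypothetical helper: assign sample indices to folds, class by class."""
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    rng = random.Random(random_state)

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_class):
        indices = by_class[label]
        if shuffle:
            rng.shuffle(indices)
        # Deal indices round-robin so each fold gets a near-equal share of this class.
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)
    return folds


labels = [0] * 80 + [1] * 20  # imbalanced: 80% class 0, 20% class 1
folds = stratified_folds(labels, n_splits=5)
fold_counts = [Counter(labels[i] for i in fold) for fold in folds]
for k, counts in enumerate(fold_counts):
    print(f"fold {k}: class 0 = {counts[0]}, class 1 = {counts[1]}")
```

With 80 and 20 samples split five ways, every fold receives 16 samples of class 0 and 4 of class 1, so each fold keeps the 80/20 ratio of the full dataset.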
Usage
    import numpy as np

    from labchain.base import XYData
    from labchain.plugins.pipelines.sequential import F3Pipeline
    from labchain.plugins.splitter import StratifiedKFoldSplitter

    # Pipeline to cross-validate; the filters and metrics are elided here.
    pipeline = F3Pipeline(filters=[...], metrics=[...])
    splitter = StratifiedKFoldSplitter(n_splits=5, shuffle=True, random_state=42, pipeline=pipeline)

    # 100 samples with 10 features each, and binary labels.
    X = XYData(value=np.random.rand(100, 10))
    y = XYData(value=np.random.randint(0, 2, 100))

    # fit runs the cross-validation and returns the mean loss over the folds.
    mean_loss = splitter.fit(X, y)
    print(f"Mean loss across folds: {mean_loss}")
Attributes:
| Name | Type | Description |
|---|---|---|
| n_splits | int | Number of folds. |
| shuffle | bool | Whether to shuffle the data before splitting. |
| random_state | int | Random seed for reproducibility. |
| pipeline | BaseFilter \| None | The pipeline to be used for training and evaluation. |
Source code in labchain/plugins/splitter/stratified_cross_validation_splitter.py
- n_splits = n_splits (instance attribute)
- pipeline = pipeline (instance attribute)
- random_state = random_state (instance attribute)
- shuffle = shuffle (instance attribute)
__init__(n_splits=5, shuffle=True, random_state=42, pipeline=None)
Initialize the StratifiedKFoldSplitter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_splits | int | Number of folds. Must be at least 2. | 5 |
| shuffle | bool | Whether to shuffle the data before splitting. | True |
| random_state | int | Controls the shuffling applied to the data before splitting. | 42 |
| pipeline | BaseFilter \| None | The pipeline used for model training and evaluation. | None |
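The interaction between `shuffle` and `random_state` can be illustrated with a small standard-library sketch. The function `shuffled_indices` is hypothetical and stands in for whatever shuffling the splitter performs internally, which this page does not show; the point is only that a fixed seed yields the same ordering on every run.

```python
import random


def shuffled_indices(n_samples, shuffle=True, random_state=42):
    """Hypothetical helper: return sample indices, optionally shuffled with a fixed seed."""
    indices = list(range(n_samples))
    if shuffle:
        # A dedicated generator keeps the result independent of global random state.
        random.Random(random_state).shuffle(indices)
    return indices


first = shuffled_indices(10, random_state=42)
second = shuffled_indices(10, random_state=42)
unshuffled = shuffled_indices(10, shuffle=False)

print(first == second)  # same seed, same order
print(unshuffled)       # original order when shuffle is disabled
```

Reusing the same seed makes fold assignments reproducible between runs, which matters when comparing losses across experiments.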
evaluate(x_data, y_true, y_pred)
Evaluate the pipeline using the provided data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_data | XYData | Input features. | required |
| y_true | XYData \| None | Ground truth labels. | required |
| y_pred | XYData | Predictions from the pipeline. | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | Evaluation metrics. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the pipeline is not fitted. |
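As a rough illustration of what a metrics dictionary of this shape might look like, the sketch below maps metric names to values computed from true and predicted labels. The names `evaluate_predictions` and `accuracy` are hypothetical, and plain lists stand in for `XYData`; the real method delegates to the configured pipeline, whose metrics are not documented here.

```python
from typing import Any, Callable, Dict, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the ground truth."""
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / len(y_true)


def evaluate_predictions(
    y_true: Sequence[int],
    y_pred: Sequence[int],
    metrics: Dict[str, Callable[[Sequence[int], Sequence[int]], float]],
) -> Dict[str, Any]:
    """Hypothetical helper: apply each named metric and collect the results."""
    return {name: metric(y_true, y_pred) for name, metric in metrics.items()}


result = evaluate_predictions(
    y_true=[1, 0, 1, 1],
    y_pred=[1, 0, 0, 1],
    metrics={"accuracy": accuracy},
)
print(result)
```

Three of the four predictions match, so the dictionary holds an accuracy of 0.75.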
fit(x, y)
Perform Stratified K-Fold cross-validation on the given data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | XYData | Input features. | required |
| y | XYData \| None | Target labels. | required |
Returns:
| Type | Description |
|---|---|
| Optional[float \| dict] | Mean loss across all folds. |
Raises:
| Type | Description |
|---|---|
| ValueError | If y is None or the pipeline is not set. |
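The overall procedure, training on all folds but one, scoring on the held-out fold, and averaging the per-fold losses, can be sketched as follows. This is an assumption-laden toy: `cross_validate` and `majority_class` are hypothetical, a majority-class predictor stands in for the pipeline, and the error rate stands in for whatever loss the pipeline's metrics produce.

```python
from collections import Counter
from statistics import mean


def majority_class(labels):
    """Toy 'model': always predict the most common training label."""
    return Counter(labels).most_common(1)[0][0]


def cross_validate(labels, folds):
    """Hypothetical helper: return the mean error rate over the held-out folds."""
    if labels is None:
        raise ValueError("labels must be provided for stratified splitting")
    losses = []
    for held_out in folds:
        held = set(held_out)
        # Train on every sample that is not in the held-out fold.
        train_labels = [y for i, y in enumerate(labels) if i not in held]
        prediction = majority_class(train_labels)
        # Loss for this fold: fraction of held-out samples predicted wrongly.
        errors = sum(1 for i in held_out if labels[i] != prediction)
        losses.append(errors / len(held_out))
    return mean(losses)


labels = [0, 0, 0, 1] * 5  # 20 samples, 25% of them in class 1
folds = [list(range(k * 4, k * 4 + 4)) for k in range(5)]  # each fold: three 0s, one 1
mean_loss = cross_validate(labels, folds)
print(f"Mean loss across folds: {mean_loss}")
```

Because every fold has the same class mix, the majority predictor misses exactly one sample in four on each fold, and the mean loss is 0.25.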
predict(x)
Make predictions using the fitted pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | XYData | Input data for prediction. | required |
Returns:
| Type | Description |
|---|---|
| XYData | Predictions from the trained pipeline. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the pipeline is not fitted. |
split(pipeline)
Set the pipeline for the splitter and disable its verbosity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline | BaseFilter | The pipeline used for training and evaluation. | required |
start(x, y, X_)
Start the cross-validation process and optionally make predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | XYData | Input training features. | required |
| y | Optional[XYData] | Target labels. | required |
| X_ | Optional[XYData] | Optional input data for prediction. | required |
Returns:
| Type | Description |
|---|---|
| Optional[XYData] | Predictions for X_ if it is provided; otherwise predictions on the training data x. |
Raises:
| Type | Description |
|---|---|
| Exception | If any error occurs during execution. |
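The control flow described above (fit first, then predict on `X_` when given and on `x` otherwise) can be sketched with a small stand-in class. `TinySplitter` is hypothetical and is not the labchain class: it uses plain lists instead of `XYData` and a majority-class rule instead of a pipeline.

```python
from collections import Counter


class TinySplitter:
    """Hypothetical stand-in used only to illustrate the start() control flow."""

    def __init__(self):
        self.majority = None

    def fit(self, x, y):
        if y is None:
            raise ValueError("y must be provided")
        # "Training" here just records the most common label.
        self.majority = Counter(y).most_common(1)[0][0]

    def predict(self, x):
        if self.majority is None:
            raise ValueError("pipeline is not fitted")
        return [self.majority] * len(x)

    def start(self, x, y, X_=None):
        self.fit(x, y)
        # Predict on X_ when it is supplied, otherwise fall back to the training data.
        target = X_ if X_ is not None else x
        return self.predict(target)


splitter = TinySplitter()
train_x = [[0.1], [0.4], [0.9]]
train_y = [1, 1, 0]

on_training = splitter.start(train_x, train_y)
on_new = splitter.start(train_x, train_y, X_=[[0.5], [0.7]])
print(on_training, on_new)
```

Calling `start` without `X_` returns one prediction per training sample, while passing `X_` returns one prediction per new sample.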