StratifiedKFoldSplitter
labchain.plugins.splitter.stratified_cross_validation_splitter
Bases: BaseSplitter
A Stratified K-Fold cross-validation splitter for evaluating classification models.
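This class implements Stratified K-Fold cross-validation, which splits the dataset into K folds while keeping the class proportions in each fold close to those of the full dataset. This matters most for imbalanced datasets, where an unstratified split can leave a fold with few or no samples of a minority class.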
Key Features
- Preserves label distribution across folds
- Configurable number of splits
- Option to shuffle data before splitting
- Supports custom pipelines for model training and evaluation
- Provides mean loss across all folds
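
To make the first feature concrete, here is a minimal sketch of how stratification distributes an imbalanced label vector across folds. It uses scikit-learn's `StratifiedKFold` directly; whether `StratifiedKFoldSplitter` wraps that class internally is an assumption, since this page does not show the implementation.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Imbalanced toy labels: 90 samples of class 0 and 10 of class 1.
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 3))  # feature values do not affect how folds are drawn

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Collect, for each validation fold, its size and its minority-class count.
fold_sizes = []
minority_per_fold = []
for _, val_idx in skf.split(X, y):
    fold_sizes.append(len(val_idx))
    minority_per_fold.append(int(y[val_idx].sum()))

print(fold_sizes)
print(minority_per_fold)
```

With 10 minority samples and 5 folds, each validation fold should receive 2 of them, so every fold keeps the 10% minority share of the full dataset.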
Usage
```python
# Import paths assume the labchain package layout; the class lives in
# labchain.plugins.splitter.stratified_cross_validation_splitter.
from labchain.plugins.splitter.stratified_cross_validation_splitter import StratifiedKFoldSplitter
from labchain.plugins.pipelines.sequential import F3Pipeline
from labchain.base import XYData
import numpy as np

pipeline = F3Pipeline(filters=[...], metrics=[...])
splitter = StratifiedKFoldSplitter(n_splits=5, shuffle=True, random_state=42, pipeline=pipeline)

# 100 samples with 10 features each, and binary labels.
X = XYData(value=np.random.rand(100, 10))
y = XYData(value=np.random.randint(0, 2, 100))

mean_loss = splitter.fit(X, y)
print(f"Mean loss across folds: {mean_loss}")
```
Attributes:
| Name | Type | Description |
|---|---|---|
| `n_splits` | `int` | Number of folds. |
| `shuffle` | `bool` | Whether to shuffle the data before splitting. |
| `random_state` | `int` | Random seed for reproducibility. |
| `pipeline` | `BaseFilter \| None` | The pipeline to be used for training and evaluation. |
Source code in labchain/plugins/splitter/stratified_cross_validation_splitter.py
`n_splits = n_splits` (instance attribute)
`pipeline = pipeline` (instance attribute)
`random_state = random_state` (instance attribute)
`shuffle = shuffle` (instance attribute)
__init__(n_splits=5, shuffle=True, random_state=42, pipeline=None)
Initialize the StratifiedKFoldSplitter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_splits` | `int` | Number of folds. Must be at least 2. | `5` |
| `shuffle` | `bool` | Whether to shuffle the data before splitting. | `True` |
| `random_state` | `int` | Controls the shuffling applied to the data before splitting. | `42` |
| `pipeline` | `BaseFilter \| None` | The pipeline used for model training and evaluation. | `None` |
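
The effect of `shuffle`, `random_state`, and the lower bound on `n_splits` can be sketched with scikit-learn's `StratifiedKFold`, which takes parameters of the same names. This is an illustration under the assumption that the splitter forwards these values to a similar fold generator; `validation_folds` is a hypothetical helper, not part of the library.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.zeros((40, 2))
y = np.array([0, 1] * 20)  # balanced binary labels

def validation_folds(seed):
    """Hypothetical helper: validation indices of each fold for a given seed."""
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    return [val_idx.tolist() for _, val_idx in skf.split(X, y)]

# The same seed reproduces exactly the same folds.
same_seed_matches = validation_folds(42) == validation_folds(42)

# scikit-learn rejects fewer than two folds at construction time.
try:
    StratifiedKFold(n_splits=1)
    rejected = False
except ValueError:
    rejected = True

print(same_seed_matches, rejected)
```

Fixing the seed is what makes fold assignments, and therefore the reported mean loss, comparable between runs.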
evaluate(x_data, y_true, y_pred)
Evaluate the pipeline using the provided data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_data` | `XYData` | Input features. | required |
| `y_true` | `XYData \| None` | Ground truth labels. | required |
| `y_pred` | `XYData` | Predictions from the pipeline. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Evaluation metrics. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the pipeline is not fitted. |
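
As a rough picture of the returned mapping, the sketch below builds a dictionary from metric names to scores for one set of predictions. The metric names and the `evaluate_predictions` helper are assumptions made for illustration; in the real class the metrics presumably come from the configured pipeline, and this page does not list them.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def evaluate_predictions(y_true, y_pred):
    """Hypothetical helper: map metric names to scores for one set of predictions."""
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "f1": float(f1_score(y_true, y_pred)),
    }

y_true = np.array([0, 0, 1, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])

metrics = evaluate_predictions(y_true, y_pred)
print(metrics)
```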
fit(x, y)
Perform Stratified K-Fold cross-validation on the given data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `XYData` | Input features. | required |
| `y` | `XYData \| None` | Target labels. | required |
Returns:
| Type | Description |
|---|---|
| `Optional[float \| dict]` | Mean loss across all folds. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `y` is `None` or the pipeline is not set. |
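
The cross-validation loop described here can be sketched as follows: split, train on each training fold, score on the held-out fold, and average. This is not the library's implementation. `mean_loss_over_folds` is a hypothetical function, `LogisticRegression` stands in for the configured pipeline, and log loss is an assumed choice of loss.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold

def mean_loss_over_folds(X, y, n_splits=5, shuffle=True, random_state=42):
    """Hypothetical stand-in for the fold loop: train, score, then average."""
    if y is None:
        # Stratification needs labels, mirroring the documented ValueError.
        raise ValueError("y must be provided for stratified splitting")
    skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    losses = []
    for train_idx, val_idx in skf.split(X, y):
        model = LogisticRegression()  # stand-in for the pipeline (assumption)
        model.fit(X[train_idx], y[train_idx])
        proba = model.predict_proba(X[val_idx])
        losses.append(log_loss(y[val_idx], proba, labels=[0, 1]))
    return float(np.mean(losses))

rng = np.random.default_rng(0)
X = rng.random((100, 10))
y = rng.integers(0, 2, 100)

mean_loss = mean_loss_over_folds(X, y)
print(f"Mean loss across folds: {mean_loss:.3f}")
```

A fresh model is created inside the loop so that no fold sees parameters learned from its own validation data.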
predict(x)
Make predictions using the fitted pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `XYData` | Input data for prediction. | required |
Returns:
| Type | Description |
|---|---|
| `XYData` | Predictions from the trained pipeline. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the pipeline is not fitted. |
split(pipeline)
Set the pipeline for the splitter and disable its verbosity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pipeline` | `BaseFilter` | The pipeline used for training and evaluation. | required |
start(x, y, X_)
Start the cross-validation process and optionally make predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `XYData` | Input training features. | required |
| `y` | `Optional[XYData]` | Target labels. | required |
| `X_` | `Optional[XYData]` | Optional input data for prediction. | required |
Returns:
| Type | Description |
|---|---|
| `Optional[XYData]` | Predictions for `X_` if it is provided, otherwise predictions on the training data. |
Raises:
| Type | Description |
|---|---|
| `Exception` | If any error occurs during execution. |
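
The fallback between `X_` and the training data can be sketched as below. The cross-validation step is omitted for brevity, `start_sketch` is a hypothetical function, and `LogisticRegression` is a stand-in for the pipeline, so treat this only as an illustration of which input the predictions are made on.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def start_sketch(x, y, x_new=None):
    """Hypothetical: fit on (x, y), then predict on x_new, or on x when x_new is None."""
    model = LogisticRegression()  # stand-in for the pipeline (assumption)
    model.fit(x, y)
    target = x if x_new is None else x_new
    return model.predict(target)

rng = np.random.default_rng(1)
x = rng.random((60, 4))
y = rng.integers(0, 2, 60)
x_new = rng.random((5, 4))

train_preds = start_sketch(x, y)        # no x_new: predictions on the training data
new_preds = start_sketch(x, y, x_new)   # x_new given: predictions on the new data

print(train_preds.shape, new_preds.shape)
```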