Probability Calibration for 3-class classification#
This example illustrates how sigmoid calibration changes predicted probabilities for a 3-class classification problem. Illustrated is the standard 2-simplex, where the three corners correspond to the three classes. Arrows point from the probability vectors predicted by an uncalibrated classifier to the probability vectors predicted by the same classifier after sigmoid calibration on a hold-out validation set. Colors indicate the true class of an instance (red: class 1, green: class 2, blue: class 3).
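To make the arrows concrete, here is a minimal NumPy sketch of what sigmoid calibration might do to a handful of 3-class probability vectors. It assumes a one-vs-rest scheme in which each class score is passed through a sigmoid of the form 1 / (1 + exp(a * p + b)) and the rows are then renormalized so they stay on the simplex. The function name sigmoid_calibrate, the input vectors and the values of a and b are all hypothetical; in practice the slope and intercept would be fitted on the validation set rather than chosen by hand.

```python
import numpy as np


def sigmoid_calibrate(proba, a, b):
    """Apply a per-class sigmoid, then renormalize each row to sum to 1.

    Illustrative helper only (not a scikit-learn function).
    """
    calibrated = 1.0 / (1.0 + np.exp(a * proba + b))
    return calibrated / calibrated.sum(axis=1, keepdims=True)


# Hypothetical uncalibrated predictions; each row is a point on the 2-simplex.
uncalibrated = np.array(
    [
        [1.0, 0.0, 0.0],  # sits exactly on the corner of class 1
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ]
)

# Hypothetical per-class slope and intercept (assumed, not fitted).
a = np.array([-4.0, -4.0, -4.0])
b = np.array([2.0, 2.0, 2.0])

calibrated = sigmoid_calibrate(uncalibrated, a, b)

# Each arrow in the figure corresponds to one row of this difference.
arrows = calibrated - uncalibrated
print(calibrated)
print(arrows)
```

Because both the start and end points lie on the simplex, the components of every arrow sum to zero. With these assumed parameters the fully confident first prediction is pulled away from its corner toward the interior.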
Data#
Below, we generate a classification dataset with 2000 samples, 2 features and 3 target classes. We then split the data as follows:
train: 600 samples (for training the classifier)
valid: 400 samples (for calibrating predicted probabilities)
test: 1000 samples
Note that we also create X_train_valid and y_train_valid, which consist of both the train and valid subsets. These are used when we only want to train the classifier but not calibrate the predicted probabilities.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np

from sklearn.datasets import make_blobs

np.random.seed(0)

X, y = make_blobs(
    n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0
)
X_train, y_train = X[:600], y[:600]
X_valid, y_valid = X[600:1000], y[600:1000]
X_train_valid, y_train_valid = X[:1000], y[:1000]
X_test, y_test = X[1000:], y[1000:]
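As a quick sanity check on the split described above, the following sketch regenerates the same dataset and reports the size and class counts of each subset. The splits dictionary and the variable names sizes and class_counts are introduced here for illustration only; they are not part of the original example.

```python
import numpy as np
from sklearn.datasets import make_blobs

X, y = make_blobs(
    n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0
)

# Index ranges matching the train / valid / test split described above.
splits = {
    "train": slice(0, 600),
    "valid": slice(600, 1000),
    "train_valid": slice(0, 1000),
    "test": slice(1000, None),
}

sizes = {name: len(y[idx]) for name, idx in splits.items()}
class_counts = {
    name: np.bincount(y[idx], minlength=3) for name, idx in splits.items()
}

for name in splits:
    print(f"{name:>11}: {sizes[name]:4d} samples, class counts {class_counts[name]}")
```

make_blobs shuffles the samples by default, so plain slicing should give each subset a mix of all three classes, although the per-class counts are not guaranteed to be exactly balanced.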
Fitting and calibration#
First, we will train a RandomForestClassifier
with 25 base estimators (trees) on the concatenated train and validation
data (1000 samples). This is the uncalibrated classifier.
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=25)
clf.fit(X_train_valid, y_train_valid)
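The uncalibrated predictions that serve as the starting points of the arrows can be inspected directly with predict_proba. The sketch below is self-contained, so it rebuilds the data and refits the forest; the random_state=0 argument is an assumption added here for reproducibility and is not set in the example above.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier

X, y = make_blobs(
    n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0
)
X_train_valid, y_train_valid = X[:1000], y[:1000]
X_test = X[1000:]

# random_state is an assumption for reproducibility (not in the original example).
clf = RandomForestClassifier(n_estimators=25, random_state=0)
clf.fit(X_train_valid, y_train_valid)

# One row per test sample, one column per class, in the order of clf.classes_.
proba = clf.predict_proba(X_test)
row_sums = proba.sum(axis=1)

print("classes:", clf.classes_)
print("shape:", proba.shape)
print("rows sum to 1:", np.allclose(row_sums, 1.0))
print("first three probability vectors:")
print(proba[:3])
```

Each row is a point on the 2-simplex: its entries are non-negative and sum to one, which is what allows the predictions to be drawn inside a triangle.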