Skip to content

Implementing Temperature Scaling #29517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 72 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
f128648
Added link for plot_adaboost_multiclass example
virchan Dec 8, 2023
6fb31ea
Moved the example link from the example itself back to the doc string
virchan Dec 10, 2023
dbbf320
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Dec 10, 2023
1b9b29e
Merge branch 'scikit-learn:main' into main
virchan Jan 26, 2024
a2c7468
Reworded the example reference of AdaBoost in the
virchan Jan 26, 2024
c2fdfc0
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Jan 26, 2024
b46f103
Merge branch 'scikit-learn:main' into main
virchan Feb 9, 2024
5742a52
Merge branch 'scikit-learn:main' into main
virchan Feb 26, 2024
eb50b06
- Added the Multi-class AdaBoosted Decision Trees example to the Deci…
virchan Feb 28, 2024
1fb3a20
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Feb 28, 2024
a188d7e
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Feb 28, 2024
0f38e66
Merge branch 'main' into main
virchan Mar 1, 2024
46318a7
Reformatted doc-strings to meet the ruff requirement
virchan Mar 1, 2024
07e66e2
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Mar 1, 2024
752a8e8
Empty commit for test re-run.
virchan Mar 1, 2024
577209b
Merge branch 'scikit-learn:main' into main
virchan Mar 4, 2024
ff6824d
Merge branch 'scikit-learn:main' into main
virchan Mar 6, 2024
6d866d4
Removed example links from two files
virchan Mar 6, 2024
c2225fe
Merge branch 'main' of https://github.com/virchan/scikit-learn
virchan Mar 6, 2024
23e6ef1
Empty commit for checks re-run
virchan Mar 6, 2024
03aa44a
Removed an empty line for checks re-run.
virchan Mar 6, 2024
350c6e4
Merge branch 'scikit-learn:main' into main
virchan Mar 7, 2024
3cd7664
Merge branch 'scikit-learn:main' into main
virchan Mar 8, 2024
0e42eaa
Merge branch 'scikit-learn:main' into main
virchan Mar 11, 2024
215a3db
Merge branch 'scikit-learn:main' into main
virchan Mar 15, 2024
6a54b8a
Merge branch 'scikit-learn:main' into main
virchan Apr 10, 2024
5f4e097
Merge branch 'scikit-learn:main' into main
virchan Apr 29, 2024
3b6eb34
Merge branch 'scikit-learn:main' into main
virchan May 3, 2024
45ab127
Merge branch 'scikit-learn:main' into main
virchan May 6, 2024
88d4500
Created the `sklearn/calibration_temperature.py` to contain all work …
virchan May 6, 2024
5f44367
Merge remote-tracking branch 'origin/main'
virchan May 6, 2024
6747608
Merge branch 'scikit-learn:main' into main
virchan May 8, 2024
7e2a444
Added the `_TemperatureScaling` class and associated helper functions.
virchan May 10, 2024
6b775a9
Merge branch 'scikit-learn:main' into main
virchan May 13, 2024
eaa0e41
Merge branch 'scikit-learn:main' into main
virchan May 14, 2024
e087c20
Merge branch 'scikit-learn:main' into main
virchan May 21, 2024
3fdcfcf
Merge branch 'scikit-learn:main' into main
virchan May 21, 2024
7d66527
Merge branch 'scikit-learn:main' into main
virchan May 22, 2024
29db0e3
Merge branch 'scikit-learn:main' into main
virchan May 24, 2024
5588049
Merge branch 'scikit-learn:main' into main
virchan May 27, 2024
bc72e7d
Merge branch 'scikit-learn:main' into main
virchan May 28, 2024
51b787b
Merge branch 'scikit-learn:main' into main
virchan May 29, 2024
d6badaf
Merge branch 'scikit-learn:main' into main
virchan May 31, 2024
eb4ccdf
Merge branch 'scikit-learn:main' into main
virchan Jun 2, 2024
de57d4e
Merge branch 'scikit-learn:main' into main
virchan Jun 5, 2024
974c40b
Merge branch 'scikit-learn:main' into main
virchan Jun 6, 2024
a2f5e4c
Merge branch 'scikit-learn:main' into main
virchan Jun 8, 2024
a643556
Merge branch 'scikit-learn:main' into main
virchan Jun 12, 2024
18d959f
Merge branch 'scikit-learn:main' into main
virchan Jun 14, 2024
cb2fc41
Merge branch 'scikit-learn:main' into main
virchan Jun 17, 2024
7acc779
- Converted variables into lowercase to reduce warning messages.
virchan Jun 18, 2024
2028b31
Merge remote-tracking branch 'origin/main'
virchan Jun 18, 2024
25a1bf2
- Converted variables into lowercase to reduce warning messages.
virchan Jun 18, 2024
42c4aa8
Merge branch 'scikit-learn:main' into main
virchan Jun 19, 2024
26f458d
Added doc-strings to temperature-scaling-related functions.
virchan Jun 19, 2024
97a5fca
Merge branch 'scikit-learn:main' into main
virchan Jun 21, 2024
8b0e7ae
Merge branch 'scikit-learn:main' into main
virchan Jun 21, 2024
2e47492
Merge branch 'scikit-learn:main' into main
virchan Jun 26, 2024
3d7c7a0
Merge branch 'scikit-learn:main' into main
virchan Jul 3, 2024
d0e58c7
Merge branch 'scikit-learn:main' into main
virchan Jul 8, 2024
6b259db
Merge branch 'scikit-learn:main' into main
virchan Jul 9, 2024
1fe2056
Merge branch 'scikit-learn:main' into main
virchan Jul 12, 2024
6160ee1
Modified the `.fit()` method of temperature scaling. Now it can handl…
virchan Jul 12, 2024
bd7576f
Merge branch 'scikit-learn:main' into main
virchan Jul 15, 2024
f5f7f77
Merge branch 'scikit-learn:main' into main
virchan Jul 17, 2024
67afee5
1. Modified the `_TemperatureScaling` class to adept `sample_weight` …
virchan Jul 17, 2024
a857923
Merge branch 'scikit-learn:main' into main
virchan Jul 18, 2024
e72adfe
Modified `_temperature_scaling_test.py` and `calibration_temperature.…
virchan Jul 18, 2024
5552e23
Merge branch 'scikit-learn:main' into main
virchan Jul 22, 2024
06b54e5
Merge branch 'scikit-learn:main' into main
virchan Jul 23, 2024
9390241
Merge branch 'scikit-learn:main' into main
virchan Jul 27, 2024
3ed6eed
Merge branch 'scikit-learn:main' into main
virchan Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions sklearn/_temperature_scaling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
'''
This file is created to test if the custom 'TemperatureScaling' class runs properly,
and serves as proof of work for the changes made to the scikit-learn repository.

The file also includes examples related to developing a temperature scaling method
for probability calibration in multi-class classification.


References:
-----------
.. [1] https://github.com/scikit-learn/scikit-learn/issues/28574. Original issue
on Github.

.. [2] On Calibration of Modern Neural Networks,
C. Guo, G. Pleiss, Y. Sun & K. Q. Weinberger, ICML 2017
'''

from sklearn.calibration_temperature import CalibratedClassifierCV_test
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

# We demonstrate with the Iris dataset, because
# it is small, multi-class, and self-provided.
X, y = datasets.load_iris(return_X_y=True)
X_train, X_calib, y_train, y_calib = train_test_split(X, y)

# Load the following classifiers for testing
# - Support vector classifier
# - Logistic regressor
# - Decision tree classifier
SV_classifier: SVC = SVC(probability=False)
Logistic_classifier: LogisticRegression = LogisticRegression()
Tree_classifier: DecisionTreeClassifier = DecisionTreeClassifier()

# Initiate the temperature scaling calibrators for the classifiers
SVC_scaled: CalibratedClassifierCV_test = CalibratedClassifierCV_test(SV_classifier,
cv=3,
method='temperature'
)
Logistic_scaled: CalibratedClassifierCV_test = CalibratedClassifierCV_test(Logistic_classifier,
cv=7,
method='temperature'
)
Tree_scaled: CalibratedClassifierCV_test = CalibratedClassifierCV_test(Tree_classifier,
cv=3,
method='temperature'
)

# Calibrate the classifiers with temperature scaling
# The calibrators are trained with the output of
# `decision_function` for the support vector classifier
# and logistic regression, while they are trained with
# `predict_proba` for the decision tree classifier.
SVC_scaled.fit(X_train,y_train)
Logistic_scaled.fit(X_train,y_train)
Tree_scaled.fit(X_train,y_train)

print("Optimal Temperatures For Each Classifiers")
print(f"{SVC_scaled.calibrated_classifiers_[0].calibrators[0].T_=}")
print(f"{Logistic_scaled.calibrated_classifiers_[0].calibrators[0].T_=}")
print(f"{Tree_scaled.calibrated_classifiers_[0].calibrators[0].T_=}")

print('\n')
print("Printing calibrated probabilities...")
print(f"{SVC_scaled.predict_proba(X_calib)=}")
print(f"{Logistic_scaled.predict_proba(X_calib)=}")
print(f"{Tree_scaled.predict_proba(X_calib)=}")
print(f"{y_calib=}")
Loading
Loading