
Implement temperature scaling for (multi-class) calibration #28574

Open
@dholzmueller

Description


Describe the workflow you want to enable

It would be great to have temperature scaling available as a post-hoc calibration method for binary and multi-class classifiers, for example in CalibratedClassifierCV.

Describe your proposed solution

Temperature scaling is a simple, efficient, and very popular post-hoc calibration method that also naturally supports the multi-class classification setting. It was proposed in Guo et al. (2017), which has >5000 citations, so it meets the inclusion criterion: http://proceedings.mlr.press/v70/guo17a.html
Unlike isotonic regression (#16321), it does not affect rank-based metrics (as long as the temperature is restricted to positive values). Moreover, it avoids the infinite-log-loss problem of isotonic regression, which can predict probabilities of exactly 0 or 1.
Temperature scaling has previously been discussed in #21785.
I experimented with different post-hoc calibration methods on 71 medium-sized (2K-50K samples) tabular classification data sets. For NNs and XGBoost, temperature scaling is competitive with isotonic regression and considerably better than Platt scaling (when Platt scaling is applied to probabilities, as implemented in scikit-learn, rather than to logits). In terms of AUC, it is considerably better than isotonic regression.
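
For reference, temperature scaling learns a single scalar T > 0 and replaces softmax(z) with softmax(z / T), where z are the logits. Since dividing by a positive constant is monotone, the ranking of the logits, and therefore accuracy and AUC, is preserved. A minimal numpy sketch of the transformation itself (not the fitting):

import numpy as np

def temperature_scale(logits, T):
    # Divide the logits by the temperature, then renormalize with softmax.
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0]])
for T in (0.5, 1.0, 2.0):
    # T > 1 softens the distribution, T < 1 sharpens it; the argmax never changes.
    print(T, temperature_scale(logits, T))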

Here is a simple implementation using PyTorch (it can be adapted to numpy). It is derived from the popular but no longer maintained implementation at https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py with the following changes:

  • using the inverse temperature as the learned parameter to prevent division-by-zero errors
  • using 50 optimizer steps instead of a single one (seemingly an error in the mentioned repo; the original paper mentions that 10 CG iterations should be enough, here it is 50 L-BFGS iterations)
  • accepting probabilities as provided by many scikit-learn estimators via predict_proba(). The code converts probabilities to logits using log(probs + 1e-10). While the logits are only determined up to a constant shift, the choice of the constant does not affect the result of temperature scaling (see the short check below).
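
As a quick sanity check of the last point: a constant shift c of the logits becomes a constant shift inv_T * c of the temperature-scaled logits, which the softmax cancels, so the fitted temperature and the calibrated probabilities are unaffected. A minimal numpy check:

import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

probs = np.array([[0.7, 0.2, 0.1]])
logits = np.log(probs + 1e-10)  # recovered only up to a constant shift
inv_T = 0.5                     # an arbitrary inverse temperature
# Shifting the logits by any constant leaves the scaled probabilities unchanged.
print(np.allclose(softmax(inv_T * logits), softmax(inv_T * (logits + 3.0))))  # True

The implementation itself: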
import torch
import torch.nn as nn
import numpy as np
from sklearn.base import BaseEstimator

class InverseTemperatureScalingCalibrator(BaseEstimator):
    # following https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py
    def _get_logits(self, X):
        # Convert probabilities to logits; the constant shift left undetermined
        # by the log cancels in the softmax, so it does not affect the result.
        X = np.asarray(X) + 1e-10
        X /= np.sum(X, axis=-1, keepdims=True)
        return torch.as_tensor(np.log(X), dtype=torch.float32)

    def fit(self, X, y):
        # X should be the probabilities as output by predict_proba()
        logits = self._get_logits(X)
        labels = torch.as_tensor(y, dtype=torch.long)  # CrossEntropyLoss expects integer labels
        # Learn the inverse temperature to avoid division by zero; start at 1/1.5.
        self.inv_temperature_ = nn.Parameter(torch.ones(1) / 1.5)
        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.LBFGS([self.inv_temperature_], lr=0.01, max_iter=50)

        def closure():
            # Standard L-BFGS closure: recompute the loss and its gradient.
            optimizer.zero_grad()
            y_pred = logits * self.inv_temperature_
            loss = criterion(y_pred, labels)
            loss.backward()
            return loss

        for _ in range(50):
            optimizer.step(closure)

        print(f'Optimal temperature: {(1. / self.inv_temperature_).item():g}')
        return self

    def predict_proba(self, X):
        # X should be the probabilities as output by predict_proba()
        logits = self._get_logits(X)
        with torch.no_grad():
            y_pred = logits * self.inv_temperature_
            return torch.softmax(y_pred, dim=-1).numpy()
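
For illustration, a hypothetical usage sketch (the dataset, classifier, and split sizes are placeholders, not a recommendation): train a classifier, fit the calibrator on held-out probabilities, and compare the test log loss before and after scaling.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=3000, n_classes=3, n_informative=6, random_state=0)
# Three-way split: train the classifier, fit the calibrator, evaluate.
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.5, random_state=0)
X_cal, X_test, y_cal, y_test = train_test_split(X_rest, y_rest, test_size=0.5, random_state=0)

clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
calibrator = InverseTemperatureScalingCalibrator().fit(clf.predict_proba(X_cal), y_cal)

probs_test = clf.predict_proba(X_test)
print('log loss before:', log_loss(y_test, probs_test))
print('log loss after: ', log_loss(y_test, calibrator.predict_proba(probs_test)))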

Describe alternatives you've considered, if relevant

Centered isotonic regression (#21454) is less popular and does not fully solve the problem of affecting rank-based metrics.
Beta calibration (#25552) seems very similar, or even partially identical, but is less well-cited and is only formulated for binary classification.

