Code for the paper:
Balanced MSE for Imbalanced Visual Regression
Jiawei Ren, Mingyuan Zhang, Cunjun Yu, Ziwei Liu
CVPR 2022 (Oral)
Check out our live demo in the Hugging Face 🤗 space!
We provide a minimal working example of Balanced MSE, using the Batch-based Monte-Carlo (BMC) implementation, on the small-scale Boston Housing dataset.
The notebook is developed on top of the Deep Imbalanced Regression (DIR) Tutorial. We thank the authors for their amazing tutorial!
A code snippet of the Balanced MSE loss is shown below. It is the BMC implementation for one-dimensional imbalanced regression, which does not require the label prior to be known beforehand.
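import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    # pred, target: [batch, 1]; logits[i, j] scores pred i against target j
    logits = - (pred - target.T).pow(2) / (2 * noise_var)
    # the matching target for pred i sits on the diagonal, so its class index is i
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0], device=pred.device))
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale
    return loss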
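As a sanity check on what the snippet above computes, the sketch below compares it with a term-by-term version of the same batch-wise softmax. The helper `bmc_loss_explicit`, the tensor shapes, and the synthetic inputs are assumptions made for illustration and are not part of the released code.

```python
import torch
import torch.nn.functional as F


def bmc_loss(pred, target, noise_var):
    # Same computation as the snippet above, repeated so this block runs on its own.
    logits = -(pred - target.T).pow(2) / (2 * noise_var)
    labels = torch.arange(pred.shape[0], device=pred.device)
    return F.cross_entropy(logits, labels) * (2 * noise_var).detach()


def bmc_loss_explicit(pred, target, noise_var):
    # Hypothetical helper: spells out the batch-wise softmax term by term.
    sq_dist = (pred - target.T).pow(2)   # [B, B], entry (i, j) = (pred_i - target_j)^2
    logits = -sq_dist / (2 * noise_var)
    # log-probability that pred_i is matched with its own target_i within the batch
    log_prob = logits.diagonal() - torch.logsumexp(logits, dim=1)
    return -log_prob.mean() * (2 * noise_var).detach()


torch.manual_seed(0)
pred = torch.randn(8, 1)         # assumed shape: [batch, 1]
target = torch.randn(8, 1)       # assumed shape: [batch, 1]
noise_var = torch.tensor(1.0)    # a tensor, because the loss calls .detach() on it

loss_a = bmc_loss(pred, target, noise_var)
loss_b = bmc_loss_explicit(pred, target, noise_var)
```

Both versions should agree up to floating-point error. Note that `noise_var` is passed as a tensor here, since a plain Python float has no `.detach()` method.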
noise_var is a hyper-parameter. It can optionally be optimized during training:
import torch
from torch.nn.modules.loss import _Loss

class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma):
        super(BMCLoss, self).__init__()
        # learnable noise scale; the loss uses its square as noise_var
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))

    def forward(self, pred, target):
        noise_var = self.noise_sigma ** 2
        return bmc_loss(pred, target, noise_var)

criterion = BMCLoss(init_noise_sigma)
# register noise_sigma with the optimizer so it is updated with its own learning rate
optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})
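To show how the pieces fit together, here is a minimal end-to-end sketch that trains a placeholder model with a learnable noise scale. The linear model, the synthetic data, the Adam optimizer, and the values chosen for `init_noise_sigma` and `sigma_lr` are all assumptions made for illustration; they are not the settings used in the paper's experiments.

```python
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


def bmc_loss(pred, target, noise_var):
    logits = -(pred - target.T).pow(2) / (2 * noise_var)
    labels = torch.arange(pred.shape[0], device=pred.device)
    return F.cross_entropy(logits, labels) * (2 * noise_var).detach()


class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma):
        super().__init__()
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))

    def forward(self, pred, target):
        return bmc_loss(pred, target, self.noise_sigma ** 2)


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)                    # placeholder regressor (assumption)
x, y = torch.randn(64, 4), torch.randn(64, 1)    # synthetic features and labels (assumption)

init_noise_sigma, sigma_lr = 1.0, 1e-2           # illustrative values only
criterion = BMCLoss(init_noise_sigma)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Add noise_sigma as a second parameter group with its own learning rate.
optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})

losses = []
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)   # the whole batch acts as the set of candidate targets
    loss.backward()                 # gradients flow to both the model and noise_sigma
    optimizer.step()
    losses.append(loss.item())
```

Because the loss normalizes over the targets in the current batch, the batch should contain more than one sample; with a single sample the softmax has only one class and the loss is zero.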
Please go into the corresponding sub-folder to run experiments:
- IMDB-WIKI-DIR
- NYUD2-DIR
- IHMR (coming soon)
@inproceedings{ren2021bmse,
title={Balanced MSE for Imbalanced Visual Regression},
author={Ren, Jiawei and Zhang, Mingyuan and Yu, Cunjun and Liu, Ziwei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
This work is supported by NTU NAP, MOE AcRF Tier 2 (T2EP20221-0033), the National Research Foundation, Singapore under its AI Singapore Programme, and under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s).
The code is developed on top of Delving into Deep Imbalanced Regression.