diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index 6fa0d53c8a44..bdbaaed65119 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1287,7 +1287,9 @@ class probabilities only when a single class label per minibatch item is too res
           :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The
           target data type is required to be long when using class indices. If containing class probabilities, the
           target must be the same shape as the input, and each value should be between :math:`[0, 1]`. This means the target
-          data type is required to be float when using class probabilities.
+          data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce
+          these probability constraints; it is the user's responsibility to ensure that
+          ``target`` contains valid probability distributions (see the examples section below for more details).
         - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
           in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
 
@@ -1314,6 +1316,51 @@ class probabilities only when a single class label per minibatch item is too res
         >>> target = torch.randn(3, 5).softmax(dim=1)
         >>> output = loss(input, target)
         >>> output.backward()
+
+    .. note::
+        When ``target`` contains class probabilities, it should consist of soft labels; that is,
+        each ``target`` entry should represent a probability distribution over the possible classes for a given data sample,
+        with individual probabilities between ``[0, 1]`` and the total distribution summing to 1.
+        This is why :func:`softmax` is applied to the ``target`` in the class probabilities example above.
+
+        PyTorch does not validate whether the values provided in ``target`` lie in the range ``[0, 1]``
+        or whether the distribution of each data sample sums to ``1``.
+        No warning is raised, and it is the user's responsibility
+        to ensure that ``target`` contains valid probability distributions.
+        Providing arbitrary values may yield misleading loss values and unstable gradients during training.
+
+        Examples:
+
+            >>> # Example of target with incorrectly specified class probabilities
+            >>> loss = nn.CrossEntropyLoss()
+            >>> torch.manual_seed(283)
+            >>> input = torch.randn(3, 5, requires_grad=True)
+            >>> target = torch.randn(3, 5)
+            >>> # Provided target class probabilities are not in range [0, 1]
+            >>> target
+            tensor([[ 0.7105,  0.4446,  2.0297,  0.2671, -0.6075],
+                    [-1.0496, -0.2753, -0.3586,  0.9270,  1.0027],
+                    [ 0.7551,  0.1003,  1.3468, -0.3581, -0.9569]])
+            >>> # Provided target class probabilities do not sum to 1
+            >>> target.sum(dim=1)
+            tensor([2.8444, 0.2462, 0.8873])
+            >>> # No error is raised, and the loss value may be misleading
+            >>> loss(input, target).item()
+            4.6379876136779785
+            >>>
+            >>> # Example of target with correctly specified class probabilities
+            >>> # Use .softmax() to obtain a true probability distribution
+            >>> target_new = target.softmax(dim=1)
+            >>> # New target class probabilities are all in range [0, 1]
+            >>> target_new
+            tensor([[0.1559, 0.1195, 0.5830, 0.1000, 0.0417],
+                    [0.0496, 0.1075, 0.0990, 0.3579, 0.3860],
+                    [0.2607, 0.1355, 0.4711, 0.0856, 0.0471]])
+            >>> # New target class probabilities sum to 1
+            >>> target_new.sum(dim=1)
+            tensor([1.0000, 1.0000, 1.0000])
+            >>> loss(input, target_new).item()
+            2.55349063873291
     """
 
     __constants__ = ["ignore_index", "reduction", "label_smoothing"]
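
For reviewers: the note added by this patch asks users to validate ``target`` themselves but leaves the actual check implicit. Below is a minimal sketch of such a validation for the ``(N, C)`` class-probability layout used in the docstring examples; the ``check_target`` helper, its parameters, and the tolerance value are hypothetical illustrations, not part of the PyTorch API or of this patch.

import torch

def check_target(target: torch.Tensor, dim: int = 1, atol: float = 1e-6) -> None:
    # Hypothetical helper: raise if ``target`` is not a batch of valid
    # class-probability distributions along ``dim``.
    if not torch.is_floating_point(target):
        raise TypeError("target must be a floating-point tensor to hold class probabilities")
    if ((target < 0) | (target > 1)).any():
        raise ValueError("every target entry must lie in [0, 1]")
    sums = target.sum(dim=dim)
    if not torch.allclose(sums, torch.ones_like(sums), atol=atol):
        raise ValueError("each target distribution must sum to 1")

raw = torch.randn(3, 5)
check_target(raw.softmax(dim=1))  # passes: rows are valid distributions
try:
    check_target(raw)  # raw Gaussian samples are not probabilities
except ValueError as err:
    print(f"rejected: {err}")

torch.allclose is used with a small absolute tolerance because softmax outputs sum to 1 only up to floating-point rounding; an exact equality test would spuriously reject valid soft labels.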