DOC: update CrossEntropyLoss with note and example of incorrect target specification #155649

Open · wants to merge 9 commits into main
49 changes: 48 additions & 1 deletion torch/nn/modules/loss.py
@@ -1287,7 +1287,9 @@ class probabilities only when a single class label per minibatch item is too restrictive
:math:`K \geq 1` in the case of K-dimensional loss, where each value should be in the range :math:`[0, C)`. The
target data type is required to be long when using class indices, as illustrated in the sketch below. If
containing class probabilities, the target must be the same shape as the input, and each value should be in the
range :math:`[0, 1]`. This means the target data type is required to be float when using class probabilities.
Note that PyTorch does not strictly enforce probability constraints on the class probabilities; it is the
user's responsibility to ensure ``target`` contains valid probability distributions (see the examples section
below for more details).
- Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.

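The long dtype requirement for class-index targets is a common stumbling point. Below is a
minimal sketch of what goes wrong when indices are given as floats (tensor values are
illustrative, and the exact error text may vary across PyTorch versions):

>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> # A float tensor of class indices is rejected rather than silently cast
>>> bad_target = torch.tensor([1.0, 0.0, 4.0])
>>> # loss(input, bad_target)  # raises RuntimeError (expected scalar type Long)
>>> # Integral (long) class indices work as documented
>>> good_target = torch.tensor([1, 0, 4])
>>> output = loss(input, good_target)
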
@@ -1314,6 +1316,51 @@ class probabilities only when a single class label per minibatch item is too restrictive
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

.. note::
When ``target`` contains class probabilities, it should consist of soft labels: that is,
each ``target`` entry should represent a probability distribution over the possible classes for a given data sample,
with individual probabilities in the range ``[0, 1]`` and the full distribution summing to 1.
This is why the :meth:`~torch.Tensor.softmax` method is applied to the ``target`` in the class probabilities example above.

PyTorch does not validate whether the values provided in ``target`` lie in the range ``[0, 1]``
or whether the distribution of each data sample sums to ``1``.
No warning is raised, and it is the user's responsibility
to ensure that ``target`` contains valid probability distributions.
Providing arbitrary values may yield misleading loss values and unstable gradients during training.
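
Because no such check is performed internally, callers who construct soft labels by hand may
want to validate them before computing the loss. A minimal sketch of such a guard (the helper
``assert_valid_probs`` is illustrative and not part of PyTorch):

>>> def assert_valid_probs(t, dim=1, atol=1e-6):
...     # every entry must lie in [0, 1] ...
...     assert bool(((t >= 0) & (t <= 1)).all()), "entries outside [0, 1]"
...     # ... and each sample's distribution must sum to 1
...     s = t.sum(dim=dim)
...     assert torch.allclose(s, torch.ones_like(s), atol=atol), "distributions do not sum to 1"
>>> assert_valid_probs(torch.randn(3, 5).softmax(dim=1))  # valid soft labels pass silently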

Examples:

>>> # Example of target with incorrectly specified class probabilities
>>> loss = nn.CrossEntropyLoss()
>>> torch.manual_seed(283)
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5)
>>> # Provided target class probabilities are not in range [0,1]
>>> target
tensor([[ 0.7105, 0.4446, 2.0297, 0.2671, -0.6075],
[-1.0496, -0.2753, -0.3586, 0.9270, 1.0027],
[ 0.7551, 0.1003, 1.3468, -0.3581, -0.9569]])
>>> # Provided target class probabilities do not sum to 1
>>> target.sum(dim=1)
tensor([2.8444, 0.2462, 0.8873])
>>> # No error message and possible misleading loss value
>>> loss(input, target).item()
4.6379876136779785
>>>
>>> # Example of target with correctly specified class probabilities
>>> # Use .softmax() to ensure true probability distribution
>>> target_new = target.softmax(dim=1)
>>> # New target class probabilities all in range [0,1]
>>> target_new
tensor([[0.1559, 0.1195, 0.5830, 0.1000, 0.0417],
[0.0496, 0.1075, 0.0990, 0.3579, 0.3860],
[0.2607, 0.1355, 0.4711, 0.0856, 0.0471]])
>>> # New target class probabilities sum to 1
>>> target_new.sum(dim=1)
tensor([1.0000, 1.0000, 1.0000])
>>> loss(input, target_new).item()
2.55349063873291
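
For reference, a valid class-probability ``target`` can also be constructed from hard labels,
for example with :func:`torch.nn.functional.one_hot` (a sketch; the label values and the names
``hard_labels``/``soft_target`` are illustrative):

>>> import torch.nn.functional as F
>>> hard_labels = torch.tensor([2, 0, 4])
>>> soft_target = F.one_hot(hard_labels, num_classes=5).float()
>>> # Each row is a valid distribution: entries in [0, 1], summing to 1
>>> soft_target.sum(dim=1)
tensor([1., 1., 1.])
>>> output = loss(input, soft_target)
>>> output.backward()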
"""

__constants__ = ["ignore_index", "reduction", "label_smoothing"]