Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss #150282


Open: zeshengzong wants to merge 5 commits into main from opt/nn/bce

Conversation

@zeshengzong (Contributor) commented Mar 31, 2025

Fixes #91545

Changes

  • Add label_smoothing param and docs
  • Add test case for label_smoothing
  • Remove duplicate description in nn.BCELoss and nn.BCEWithLogitsLoss
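For context, a minimal usage sketch of the new parameter. This is hedged: it assumes a build that includes this PR, the value 0.1 is illustrative, and the smoothing formula (targets mapped to target * (1 - eps) + (1 - target) * eps) is inferred from the tests reviewed below.

import torch
import torch.nn as nn

logits = torch.randn(8, 4)
probs = torch.sigmoid(logits)
target = torch.randint(0, 2, (8, 4)).float()

# label_smoothing is the parameter added by this PR; 0.1 is an illustrative value.
loss_probs = nn.BCELoss(label_smoothing=0.1)(probs, target)
loss_logits = nn.BCEWithLogitsLoss(label_smoothing=0.1)(logits, target)

# Per the PR's tests, this should match smoothing the targets by hand:
eps = 0.1
smoothed = target * (1 - eps) + (1 - target) * eps
assert torch.allclose(loss_probs, nn.BCELoss()(probs, smoothed))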

Test Result

pytest -s test/test_nn.py -k test_bce

(screenshots of passing test runs omitted)

pytorch-bot bot commented Mar 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150282

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 8 Pending

As of commit b4750fb with merge base 842cc77:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: nn label on Mar 31, 2025
@zeshengzong marked this pull request as ready for review on March 31, 2025 08:40
@albanD removed their request for review on April 7, 2025 19:50
@zeshengzong (Contributor, Author):

@pytorchbot rebase -b main

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased opt/nn/bce onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout opt/nn/bce && git pull --rebase)

@zeshengzong (Contributor, Author):

Hello @jbschlosser @mikaylagawarecki, please help review this PR, thanks!

@zeshengzong (Contributor, Author):

Hello @jbschlosser @mikaylagawarecki, please help review the change when available, thanks!

@jbschlosser (Contributor) left a comment:

Python side impl looks good to me, thanks for the contribution :)

I will leave it to torch.nn maintainer @mikaylagawarecki for the final stamp though!

@mikaylagawarecki (Contributor):

Hi @zeshengzong I'll leave a proper review on this before the end of the week. Thank you for your patience and your multiple contributions!

@mikaylagawarecki (Contributor) left a comment:

just nit comments, thanks!

) -> Tensor:
    r"""Compute Binary Cross Entropy between the target and input probabilities.

    See :class:`~torch.nn.BCELoss` for details.
@mikaylagawarecki (Contributor) commented Aug 8, 2025:

Hm, why delete the args here and below? The convention in this file seems to be to document the args even though there's the see :class: ... on all the ops

@zeshengzong (Contributor, Author) replied:

Hi, I previously saw a comment saying the docs of some torch.nn.functional methods are intentionally left empty to avoid duplicating the class docs, e.g. torch.nn.functional.adaptive_avg_pool1d, torch.nn.functional.adaptive_avg_pool2d, and torch.nn.functional.adaptive_avg_pool3d, but some methods, like this one, do have param docs.

I think it would be better to make the torch.nn.functional docs consistent (either all have param docs, or all are left empty and point users to the class docs), so people don't mistake the missing param docs for an oversight. But I'm not sure which is the right way to fix them. WDYT? Thanks!

@mikaylagawarecki (Contributor) replied:

oh, hmm I'm not sure what you mean as it seems the Args are still documented for adaptive_avg_pool1d

pytorch/torch/nn/functional.py

Lines 1357 to 1358 in d596624

Args:
    output_size: the target output size (single integer)

I think it would be good not to delete these. If there is an inconsistent convention we can resolve that in a separate PR
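For concreteness, a hypothetical example of how the new arg could be documented under the existing Args convention in torch/nn/functional.py (wording modeled on the label_smoothing doc of CrossEntropyLoss; not the PR's actual text):

Args:
    label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
        of smoothing applied to the targets, where 0.0 means no smoothing.
        Default: 0.0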

test/test_nn.py Outdated
Comment on lines 4700 to 4723
def test_bce_label_smoothing_errors(self):
    N, C = 3, 4
    inputs = torch.randn((N, C))
    target = torch.randn((N, C))
    for loss_fn in (nn.BCELoss, nn.BCEWithLogitsLoss):
        loss = loss_fn(label_smoothing=1.2)
        with self.assertRaisesRegex(AssertionError,
                                    r"label_smoothing must be between 0\.0"):
            loss(inputs, target)

def test_bce_label_smoothing(self):
    N, C = 3, 4
    inputs = torch.rand((N, C))
    target = torch.rand((N, C))
    label_smoothings = [0.05, 0.15]

    for loss_fn, label_smoothing in product([nn.BCELoss, nn.BCEWithLogitsLoss], label_smoothings):
        loss = loss_fn(label_smoothing=label_smoothing)
        output_with_smoothing = loss(inputs, target)
        target_with_smoothing = target * (1 - label_smoothing) + (1 - target) * label_smoothing
        loss = loss_fn()
        output_with_manual_smoothing = loss(inputs, target_with_smoothing)
        self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
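For reference, the identity the second test checks can be written out explicitly (notation introduced here: eps is label_smoothing, t the target, p the input probability):

$$t' = t(1-\varepsilon) + (1-t)\varepsilon, \qquad \ell_\varepsilon(p, t) = -\bigl(t' \log p + (1-t') \log(1-p)\bigr)$$

Since BCE is affine in the target, evaluating the loss with label_smoothing=eps is the same as evaluating the unsmoothed loss on the smoothed target t', which is exactly what the assertEqual verifies.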

@mikaylagawarecki (Contributor) commented Aug 8, 2025:

would be nice to update the ModuleInputs instead :) (test/test_nn.py is somewhat legacy)

def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: list[tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('weights', {'weight': make_weight((10,))}),
    ]

    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
        result = -(t * i.log() + (1 - t) * (1 - i).log())
        if weight is not None:
            result = result * weight
        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
                                                    make_target((15, 10)).gt(0).to(dtype)),
                        desc=desc,
                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
        )

    scalar_weight = make_weight(())
    module_inputs.append(
        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
                                                make_target(()).gt(0).to(dtype)),
                    desc='scalar_weight',
                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
    )

    return module_inputs


def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: list[tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('weights', {'weight': make_weight((10,))}),
        ('scalar_weights', {'weight': make_weight(())})
    ]

    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
        # TODO: add pos_weight to the definition here and corresponding SampleInputs
        max_val = (-i).clamp(min=0)
        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
        if weight is not None:
            result = result * weight
        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
                                                    make_target((15, 10)).gt(0).to(dtype)),
                        desc=desc,
                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
        )

    return module_inputs
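Following the suggestion above, a hypothetical sketch of how a label_smoothing case could be added inside module_inputs_torch_nn_BCELoss. The 0.1 value is illustrative and the extended reference-fn signature is an assumption, not the PR's actual code:

    # Hypothetical extra constructor case alongside the existing ones:
    cases.append(('label_smoothing', {'label_smoothing': 0.1}))

    # Hypothetical extension: let the reference fn accept label_smoothing and
    # smooth the target first, mirroring test_bce_label_smoothing above.
    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0):
        t = t * (1 - label_smoothing) + (1 - t) * label_smoothing
        result = -(t * i.log() + (1 - t) * (1 - i).log())
        if weight is not None:
            result = result * weight
        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

Because each case is wired up via reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs), the new kwarg would flow through automatically.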

@zeshengzong (Contributor, Author) replied:

Changed, thanks!

@mikaylagawarecki (Contributor) left a comment:

Looks good, please run lint again and remove the Args deletion for now :)

@zeshengzong (Contributor, Author):

Changed, thanks!

@zeshengzong (Contributor, Author):

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label on Aug 12, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/trunk, merging, open source, release notes: nn

Development: successfully merging this pull request may close issue #91545, "Adding label smoothing option to nn.BCELoss and nn.BCEWithLogitsLoss?"

6 participants