Add `label_smoothing` param in `nn.BCELoss` and `nn.BCEWithLogitsLoss` #150282
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150282
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b4750fb with merge base 842cc77.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot rebase -b main
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here.
Successfully rebased; force-pushed from 4ffccbd to f088eaa.
Hello @jbschlosser @mikaylagawarecki, please help review this PR, thanks!
Hello @jbschlosser @mikaylagawarecki, please help review the change when available, thanks!
Python side impl looks good to me, thanks for the contribution :)
I will leave it to torch.nn maintainer @mikaylagawarecki for the final stamp though!
Hi @zeshengzong, I'll leave a proper review on this before the end of the week. Thank you for your patience and your multiple contributions!
just nit comments, thanks!
torch/nn/functional.py (outdated):
```python
) -> Tensor:
    r"""Compute Binary Cross Entropy between the target and input probabilities.

    See :class:`~torch.nn.BCELoss` for details.
```
Hm, why delete the args here and below? The convention in this file seems to be to document the args even though there's the see :class: ... on all the ops
Hi, I previously saw a comment saying the docs of torch.nn.functional methods were intentionally left empty to avoid duplicating the class documentation, as in torch.nn.functional.adaptive_avg_pool1d, torch.nn.functional.adaptive_avg_pool2d, and torch.nn.functional.adaptive_avg_pool3d, but some methods do have param docs, like this one.
I think it would be better to make the torch.nn.functional docs consistent (either all have param docs, or all are left empty and point users to the class docs), so people don't mistake the missing param docs for an oversight. But I'm not sure which is the right way to fix them. WDYT? Thanks!
Oh, hmm, I'm not sure what you mean, as it seems the Args are still documented for adaptive_avg_pool1d:
pytorch/torch/nn/functional.py, lines 1357 to 1358 in d596624:
```python
    Args:
        output_size: the target output size (single integer)
```
I think it would be good not to delete these. If there is an inconsistent convention, we can resolve that in a separate PR.
test/test_nn.py (outdated):
```python
def test_bce_label_smoothing_errors(self):
    N, C = 3, 4
    inputs = torch.randn((N, C))
    target = torch.randn((N, C))
    for loss_fn in (nn.BCELoss, nn.BCEWithLogitsLoss):
        loss = loss_fn(label_smoothing=1.2)
        with self.assertRaisesRegex(AssertionError,
                                    r"label_smoothing must be between 0\.0"):
            loss(inputs, target)

def test_bce_label_smoothing(self):
    N, C = 3, 4
    inputs = torch.rand((N, C))
    target = torch.rand((N, C))
    label_smoothings = [0.05, 0.15]

    for loss_fn, label_smoothing in product([nn.BCELoss, nn.BCEWithLogitsLoss], label_smoothings):
        loss = loss_fn(label_smoothing=label_smoothing)
        output_with_smoothing = loss(inputs, target)
        target_with_smoothing = target * (1 - label_smoothing) + (1 - target) * label_smoothing
        loss = loss_fn()
        output_with_manual_smoothing = loss(inputs, target_with_smoothing)
        self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
```
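For context, the smoothing these tests verify reduces to a small target transform plus a range check. Here is a minimal sketch of both (a hypothetical helper, not the PR's actual code; the name `smooth_binary_targets` is mine, and the assert message mirrors what test_bce_label_smoothing_errors expects):

```python
import torch

def smooth_binary_targets(target: torch.Tensor, label_smoothing: float) -> torch.Tensor:
    # Range check matching the message asserted in test_bce_label_smoothing_errors.
    assert 0.0 <= label_smoothing <= 1.0, (
        f"label_smoothing must be between 0.0 and 1.0, but got {label_smoothing}"
    )
    # Interpolate targets toward 0.5: hard 1s become 1 - eps, hard 0s become eps.
    return target * (1 - label_smoothing) + (1 - target) * label_smoothing
```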
would be nice to update the ModuleInputs instead :) (test/test_nn.py is somewhat legacy)
pytorch/torch/testing/_internal/common_modules.py, lines 1455 to 1542 in 2247aa6:
```python
def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: list[tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('weights', {'weight': make_weight((10,))}),
    ]

    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
        result = -(t * i.log() + (1 - t) * (1 - i).log())
        if weight is not None:
            result = result * weight
        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
                                                    make_target((15, 10)).gt(0).to(dtype)),
                        desc=desc,
                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
        )

    scalar_weight = make_weight(())
    module_inputs.append(
        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
                                                make_target(()).gt(0).to(dtype)),
                    desc='scalar_weight',
                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
    )

    return module_inputs


def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: list[tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('weights', {'weight': make_weight((10,))}),
        ('scalar_weights', {'weight': make_weight(())})
    ]

    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
        # TODO: add pos_weight to the definition here and corresponding SampleInputs
        max_val = (-i).clamp(min=0)
        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
        if weight is not None:
            result = result * weight
        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
                                                    make_target((15, 10)).gt(0).to(dtype)),
                        desc=desc,
                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
        )

    return module_inputs
```
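A rough sketch of what adding label_smoothing coverage to these ModuleInputs could look like (the case name, the 0.15 value, and the extra keyword with its default are assumptions, not the PR's exact diff):

```python
# Hypothetical extra entry appended to `cases` in both functions:
#     ('label_smoothing', {'label_smoothing': 0.15}),
# with the reference function smoothing the target before the usual BCE formula:
def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0):
    # Fold the smoothing into the target, matching the transform in the tests above.
    t = t * (1 - label_smoothing) + (1 - t) * label_smoothing
    result = -(t * i.log() + (1 - t) * (1 - i).log())
    if weight is not None:
        result = result * weight
    if reduction == 'none':
        return result
    elif reduction == 'mean':
        return result.sum() / i.numel()
    else:
        return result.sum()
```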
Changed, thanks!
Force-pushed from 653b42e to d596624.
Looks good, please run lint again and revert the Args deletion for now :)
Changed, thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #91545
Changes
Add `label_smoothing` param and docs in `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
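A minimal usage sketch of the new parameter (values are illustrative; the constructor keyword follows the tests above):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss(label_smoothing=0.1)
logits = torch.randn(8, 1)
target = torch.randint(0, 2, (8, 1), dtype=torch.float32)
loss = criterion(logits, target)  # targets are smoothed toward 0.5 internally
```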
Test Result