Skip to content

Commit 9d59b51

Browse files
zeshengzongpytorchmergebot
authored andcommitted
Make device check throw specific error (#155085)
Fixes #122757 The fix is lost after revert and rebase previous PR #150750 (only change of tests are merged). ## Test Result ```python >>> import torch >>> >>> model_output = torch.randn(10, 5).cuda() >>> labels = torch.randint(0, 5, (10,)).cuda() >>> weights = torch.randn(5) >>> >>> loss_fn = torch.nn.CrossEntropyLoss(weight=weights) >>> loss = loss_fn(input=model_output, target=labels) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward return F.cross_entropy( ^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/functional.py", line 3476, in cross_entropy return torch._C._nn.cross_entropy_loss( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward) ``` Pull Request resolved: #155085 Approved by: https://github.com/mikaylagawarecki
1 parent 07da8a4 commit 9d59b51

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

aten/src/ATen/core/adaption.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ namespace c10::impl {
55

66
void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
77
TORCH_CHECK(false,
8-
"Expected all tensors to be on the same device, but "
9-
"found at least two devices, ", common_device, " and ", tensor.device(), "! "
10-
"(when checking argument for argument ", argName, " in method ", methodName, ")");
8+
"Expected all tensors to be on the same device, but got ", argName, " is on ", tensor.device(),
9+
", different from other tensors on ", common_device, " (when checking argument in method ", methodName, ")");
1110
}
1211

1312
} // namespace c10::impl

0 commit comments

Comments
 (0)