Skip to content

Commit b7818f6

Browse files
Removed expand_as() call in aminmax_backward
1 parent f7e09b0 commit b7818f6

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,13 @@ Tensor aminmax_backward(
238238

239239
if (grad_min.defined()) {
240240
auto grad_min_full =
241-
restore_reduced_dims(grad_min, dims, keepdim)
242-
.expand_as(min_mask);
241+
restore_reduced_dims(grad_min, dims, keepdim);
243242
result = scale_grad_by_count(grad_min_full, min_mask, dims);
244243
}
245244

246245
if (grad_max.defined()) {
247246
auto grad_max_full =
248-
restore_reduced_dims(grad_max, dims, keepdim)
249-
.expand_as(max_mask);
247+
restore_reduced_dims(grad_max, dims, keepdim);
250248
auto grad_max_res = scale_grad_by_count(grad_max_full, max_mask, dims);
251249
result = result.defined() ? result + grad_max_res
252250
: grad_max_res;

0 commit comments

Comments
 (0)