Skip to content

Commit d88bfc9

Browse files
author
chronos_secgrp_pytorch_oss_ci_oncall
committed
2021-07-15 nightly release (9496521)
1 parent 7a31774 commit d88bfc9

File tree

112 files changed

+3593
-2161
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+3593
-2161
lines changed

aten/src/ATen/BatchingRegistrations.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1177,10 +1177,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
11771177
TRIVIAL_OP(imag)
11781178
TRIVIAL_OP(real);
11791179
TRIVIAL_OP(view_as_real);
1180-
TRIVIAL_OP(_view_as_real_physical);
11811180
TRIVIAL_OP(conj);
11821181
TRIVIAL_OP(_conj);
11831182
TRIVIAL_OP(resolve_conj);
1183+
TRIVIAL_OP(resolve_neg);
11841184
m.impl("view_as_complex", view_as_complex_batching_rule);
11851185
#undef TRIVIAL
11861186

aten/src/ATen/ConjugateFallback.cpp

+16-109
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,26 @@
1-
#include <ATen/ATen.h>
2-
#include <ATen/core/dispatch/Dispatcher.h>
3-
#include <ATen/core/op_registration/op_registration.h>
4-
#include <ATen/native/UnaryOps.h>
5-
#include <ATen/NativeFunctions.h>
6-
#include <c10/util/irange.h>
7-
#include <torch/library.h>
1+
#include <ATen/native/MathBitsFallback.h>
82

93
namespace at {
104

11-
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
12-
// Situations to handle:
13-
// 1. Out-of-place operation. Easy: materialize all inputs and
14-
// call it a day.
15-
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
16-
// Materialize other inputs as in (1).
17-
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
18-
// Materialize other inputs as in (1).
19-
//
20-
// It is important to be able to tell if we READ from an argument and if we
21-
// WRITE from an argument. Conservative approach is to assume that we always
22-
// READ from an argument, but in out-of-place operations you can skip
23-
// conjugating inputs on entry that never get used. In current schema we
24-
// can't easily tell if inplace situation has happened, so don't do it.
25-
26-
const auto& arguments = op.schema().arguments();
27-
const auto num_arguments = arguments.size();
28-
const auto stack_start = stack->size() - num_arguments;
29-
30-
c10::optional<bool> is_write;
31-
for (const auto i : c10::irange(num_arguments)) {
32-
const auto& alias_info = arguments[i].alias_info();
33-
// Three possible states:
34-
// 1. alias_info has no value --> out-of-place operation
35-
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
36-
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
37-
if (alias_info.has_value()) {
38-
if (is_write.has_value()) {
39-
TORCH_CHECK(*is_write == alias_info->isWrite(),
40-
"Unsupported operator for conjugate fallback: ", op.schema().name(),
41-
"Conjugate fallback doesn't work for operators with a mix "
42-
"mutable and non-mutable inputs that alias with outputs, "
43-
"this must be implemented manually. "
44-
"If you got this error on a core op, please report a bug to PyTorch.");
45-
} else {
46-
is_write = alias_info->isWrite();
47-
}
48-
}
5+
struct ConjFallback : MathOpFallback {
6+
ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
7+
bool is_bit_set(const Tensor& tensor) override {
8+
return tensor.is_conj();
499
}
50-
51-
if (is_write.has_value() && !*is_write) {
52-
// We assume that view operators automatically handle conjugation
53-
// correctly by propagating the Conjugate dispatch key in key_set.
54-
// This is not necessarily always right, so you should test these cases.
55-
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
56-
return;
10+
void _set_bit(const Tensor& tensor, bool value) override {
11+
return tensor._set_conj(value);
5712
}
58-
59-
// Mutable inputs to be tracked separately
60-
std::vector<Tensor> mutable_inputs;
61-
62-
for (const auto i : c10::irange(num_arguments)) {
63-
auto& ivalue = (*stack)[stack_start + i];
64-
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
65-
continue;
66-
}
67-
const auto& argument = arguments[i];
68-
bool mut_arg = false;
69-
if (argument.alias_info()) {
70-
// View operations were already filtered above, so only in-place/out= operations should get here.
71-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
72-
mut_arg = true;
73-
}
74-
if (ivalue.isTensor()) {
75-
auto* impl = ivalue.unsafeToTensorImpl();
76-
if (!impl->is_conj()) {
77-
continue;
78-
}
79-
80-
auto tensor = std::move(ivalue).toTensor();
81-
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
82-
if (mut_arg) {
83-
// TODO: This is a waste if the argument is write only
84-
tensor._set_conj(false);
85-
at::conj_physical_(tensor);
86-
mutable_inputs.emplace_back(tensor);
87-
} else {
88-
tensor = at::resolve_conj(tensor);
89-
}
90-
(*stack)[stack_start + i] = std::move(tensor);
91-
} else if (ivalue.isTensorList()) {
92-
auto tensors = std::move(ivalue).toTensorList();
93-
if (mut_arg) {
94-
for(const auto j : c10::irange(tensors.size())) {
95-
Tensor t = tensors[j];
96-
t._set_conj(false);
97-
at::conj_physical_(t);
98-
mutable_inputs.emplace_back(t);
99-
}
100-
} else {
101-
for(const auto j : c10::irange(tensors.size())) {
102-
tensors[j] = at::resolve_conj(tensors[j]);
103-
}
104-
}
105-
(*stack)[stack_start + i] = std::move(tensors);
106-
}
13+
Tensor resolve_bit(const Tensor& tensor) override {
14+
return at::resolve_conj(tensor);
10715
}
16+
Tensor& math_op_(Tensor& tensor) override {
17+
return at::conj_physical_(tensor);
18+
}
19+
};
10820

109-
110-
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
111-
112-
for (auto& mutable_input : mutable_inputs) {
113-
at::conj_physical_(mutable_input);
114-
mutable_input._set_conj(true);
115-
}
21+
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
22+
ConjFallback object;
23+
object.fallback_impl(op, dispatch_keys, stack);
11624
}
11725

11826
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
@@ -142,7 +50,6 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
14250
m.impl("size.int", torch::CppFunction::makeFallthrough());
14351
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
14452
m.impl("is_complex", torch::CppFunction::makeFallthrough());
145-
m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
14653
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
14754
m.impl("imag", torch::CppFunction::makeFallthrough());
14855
m.impl("real", torch::CppFunction::makeFallthrough());

aten/src/ATen/core/aten_interned_strings.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ _(aten, conj) \
241241
_(aten, conj_physical) \
242242
_(aten, conj_physical_) \
243243
_(aten, resolve_conj) \
244+
_(aten, resolve_neg) \
244245
_(aten, complex) \
245246
_(aten, copysign) \
246247
_(aten, polar) \
@@ -768,7 +769,6 @@ _(aten, zeros_like) \
768769
_(aten, real) \
769770
_(aten, imag) \
770771
_(aten, view_as_real) \
771-
_(aten, _view_as_real_physical) \
772772
_(aten, view_as_complex) \
773773
/* nothing */
774774

Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#pragma once
22

33
#include <ATen/cpu/vec/vec256/functional_base.h>
4+
#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
45
#include <ATen/cpu/vec/vec256/functional_bfloat16.h>
6+
#endif

aten/src/ATen/native/BinaryOps.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,18 @@ Tensor add_relu(const Tensor& self, const Tensor& other, const Scalar& alpha) {
368368
return add_relu_impl(result, self, other, alpha);
369369
}
370370

371+
Tensor add_relu(const Tensor& self, const Scalar& other, const Scalar& alpha) {
372+
return add_relu(self, wrapped_scalar_tensor(other), alpha);
373+
}
374+
371375
Tensor& add_relu_(Tensor& self, const Tensor& other, const Scalar& alpha) {
372376
return add_relu_impl(self, self, other, alpha);
373377
}
374378

379+
Tensor& add_relu_(Tensor& self, const Scalar& other, const Scalar& alpha) {
380+
return add_relu_(self, wrapped_scalar_tensor(other), alpha);
381+
}
382+
375383
TORCH_IMPL_FUNC(copysign_out) (
376384
const Tensor& self, const Tensor& other, const Tensor& result
377385
) {

aten/src/ATen/native/ComplexHelper.h

+9-9
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,8 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
3131
return res;
3232
}
3333

34-
// expects as input a complex tensor and returns back a tensor
35-
// with corresponding real dtype containing the complex values
36-
// in the last two dimensions
37-
Tensor view_as_real(const Tensor& self) {
38-
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
39-
return native::_view_as_real_physical(self);
40-
}
41-
4234
Tensor _view_as_real_physical(const Tensor& self) {
43-
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
35+
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
4436
auto old_sizes = self.sizes();
4537
DimVector new_sizes(old_sizes.size() + 1);
4638
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
@@ -53,6 +45,14 @@ Tensor _view_as_real_physical(const Tensor& self) {
5345
return real_tensor;
5446
}
5547

48+
// expects as input a complex tensor and returns back a tensor
49+
// with corresponding real dtype containing the complex values
50+
// in the last two dimensions
51+
Tensor view_as_real(const Tensor& self) {
52+
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
53+
return _view_as_real_physical(self);
54+
}
55+
5656
inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
5757
const int64_t dim = oldstride.size();
5858
TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");
+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/op_registration/op_registration.h>
3+
#include <torch/library.h>
4+
#include <ATen/core/dispatch/Dispatcher.h>
5+
#include <ATen/native/UnaryOps.h>
6+
#include <ATen/NativeFunctions.h>
7+
8+
namespace at {
9+
10+
// This fallback should only be used for operations that are self inverse and have a corresponding tensor
11+
// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
12+
// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
13+
// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
14+
15+
struct MathOpFallback {
16+
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(op_name_) {}
17+
virtual bool is_bit_set(const Tensor&) = 0;
18+
virtual void _set_bit(const Tensor&, bool) = 0;
19+
// materializes the bit, i.e., returns a new tensor containing the true output
20+
// (after performing the math operation corresponding to the tensor bit) if the bit is set to 1
21+
// else returns self.
22+
virtual Tensor resolve_bit(const Tensor&) = 0;
23+
// in-place operation corresponding to the math op represented by the bit. In the future, if this class
24+
// is generalized for ops that are not self inverse, then this must be replaced by op_inverse_inplace
25+
virtual Tensor& math_op_(Tensor&) = 0;
26+
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
27+
// Situations to handle:
28+
// 1. Out-of-place operation. Easy: materialize all inputs and
29+
// call it a day.
30+
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
31+
// Materialize other inputs as in (1).
32+
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
33+
// Materialize other inputs as in (1).
34+
//
35+
// It is important to be able to tell if we READ from an argument and if we
36+
// WRITE from an argument. Conservative approach is to assume that we always
37+
// READ from an argument, but in out-of-place operations you can skip
38+
// conjugating inputs on entry that never get used. In current schema we
39+
// can't easily tell if inplace situation has happened, so don't do it.
40+
41+
const auto& arguments = op.schema().arguments();
42+
const auto num_arguments = arguments.size();
43+
const auto stack_start = stack->size() - num_arguments;
44+
45+
c10::optional<bool> is_write;
46+
for (int64_t i = 0; i < num_arguments; ++i) {
47+
// Three possible states:
48+
// 1. alias_info has no value --> out-of-place operation
49+
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
50+
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
51+
const auto& alias_info = arguments[i].alias_info();
52+
if (alias_info.has_value()) {
53+
if (is_write.has_value()) {
54+
TORCH_CHECK(*is_write == alias_info->isWrite(),
55+
"Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
56+
op_name, " fallback doesn't work for operators with a mix of "
57+
"mutable and non-mutable inputs that alias with outputs, "
58+
"this must be implemented manually. "
59+
"If you got this error on a core op, please report a bug to PyTorch.");
60+
} else {
61+
is_write = alias_info->isWrite();
62+
}
63+
}
64+
}
65+
66+
if (is_write.has_value() && !*is_write) {
67+
// We assume that view operators automatically handle the math bit
68+
// correctly by propagating the dispatch key in key_set.
69+
// This is not necessarily always right, so you should test these cases.
70+
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
71+
return;
72+
}
73+
74+
// Mutable inputs to be tracked separately
75+
std::vector<Tensor> mutable_inputs;
76+
77+
for (int64_t i = 0; i < num_arguments; ++i) {
78+
auto& ivalue = (*stack)[stack_start + i];
79+
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
80+
continue;
81+
}
82+
const auto& argument = arguments[i];
83+
bool mut_arg = false;
84+
if (argument.alias_info()) {
85+
// Was already tested by is_write loop above
86+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
87+
mut_arg = true;
88+
}
89+
if (ivalue.isTensor()) {
90+
if (!is_bit_set(ivalue.toTensor())) {
91+
continue;
92+
}
93+
94+
auto tensor = std::move(ivalue).toTensor();
95+
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), op_name, " fallback does not support meta tensors.");
96+
if (mut_arg) {
97+
// TODO: This is a waste if the argument is write only
98+
_set_bit(tensor, false);
99+
math_op_(tensor);
100+
mutable_inputs.emplace_back(tensor);
101+
} else {
102+
tensor = resolve_bit(tensor);
103+
}
104+
(*stack)[stack_start + i] = std::move(tensor);
105+
} else if (ivalue.isTensorList()) {
106+
auto tensors = std::move(ivalue).toTensorList();
107+
if (mut_arg) {
108+
for(const auto j : c10::irange(tensors.size())) {
109+
Tensor t = tensors[j];
110+
_set_bit(t, false);
111+
math_op_(t);
112+
mutable_inputs.emplace_back(t);
113+
}
114+
} else {
115+
for(const auto j : c10::irange(tensors.size())) {
116+
tensors[j] = resolve_bit(tensors[j]);
117+
}
118+
}
119+
(*stack)[stack_start + i] = std::move(tensors);
120+
}
121+
}
122+
123+
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
124+
125+
for (auto& mutable_input : mutable_inputs) {
126+
math_op_(mutable_input);
127+
_set_bit(mutable_input, true);
128+
}
129+
}
130+
131+
virtual ~MathOpFallback() = default;
132+
133+
DispatchKey key;
134+
string op_name;
135+
};
136+
137+
} // namespace at

0 commit comments

Comments
 (0)