@@ -1,118 +1,26 @@
-#include <ATen/ATen.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/core/op_registration/op_registration.h>
-#include <ATen/native/UnaryOps.h>
-#include <ATen/NativeFunctions.h>
-#include <c10/util/irange.h>
-#include <torch/library.h>
+#include <ATen/native/MathBitsFallback.h>

 namespace at {

-void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
-  // Situations to handle:
-  //  1. Out-of-place operation. Easy: materialize all inputs and
-  //     call it a day.
-  //  2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
-  //     Materialize other inputs as in (1).
-  //  3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
-  //     Materialize other inputs as in (1).
-  //
-  // It is important to be able to tell if we READ from an argument and if we
-  // WRITE from an argument. Conservative approach is to assume that we always
-  // READ from an argument, but in out-of-place operations you can skip
-  // conjugating inputs on entry that never get used. In current schema we
-  // can't easily tell if inplace situation has happened, so don't do it.
-
-  const auto& arguments = op.schema().arguments();
-  const auto num_arguments = arguments.size();
-  const auto stack_start = stack->size() - num_arguments;
-
-  c10::optional<bool> is_write;
-  for (const auto i : c10::irange(num_arguments)) {
-    const auto& alias_info = arguments[i].alias_info();
-    // Three possible states:
-    // 1. alias_info has no value --> out-of-place operation
-    // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
-    // 3. alias_info does have a value, alias_info->is_write=False --> view operation
-    if (alias_info.has_value()) {
-      if (is_write.has_value()) {
-        TORCH_CHECK(*is_write == alias_info->isWrite(),
-          "Unsupported operator for conjugate fallback: ", op.schema().name(),
-          "Conjugate fallback doesn't work for operators with a mix "
-          "mutable and non-mutable inputs that alias with outputs, "
-          "this must be implemented manually. "
-          "If you got this error on a core op, please report a bug to PyTorch.");
-      } else {
-        is_write = alias_info->isWrite();
-      }
-    }
+struct ConjFallback : MathOpFallback {
+  ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
+  bool is_bit_set(const Tensor& tensor) override {
+    return tensor.is_conj();
   }
-
-  if (is_write.has_value() && !*is_write) {
-    // We assume that view operators automatically handle conjugation
-    // correctly by propagating the Conjugate dispatch key in key_set.
-    // This is not necessarily always right, so you should test these cases.
-    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
-    return;
+  void _set_bit(const Tensor& tensor, bool value) override {
+    return tensor._set_conj(value);
   }
-
-  // Mutable inputs to be tracked separately
-  std::vector<Tensor> mutable_inputs;
-
-  for (const auto i : c10::irange(num_arguments)) {
-    auto& ivalue = (*stack)[stack_start + i];
-    if (!(ivalue.isTensor() || ivalue.isTensorList())) {
-      continue;
-    }
-    const auto& argument = arguments[i];
-    bool mut_arg = false;
-    if (argument.alias_info()) {
-      // View operations were already filtered above, so only in-place/out= operations should get here.
-      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
-      mut_arg = true;
-    }
-    if (ivalue.isTensor()) {
-      auto* impl = ivalue.unsafeToTensorImpl();
-      if (!impl->is_conj()) {
-        continue;
-      }
-
-      auto tensor = std::move(ivalue).toTensor();
-      TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
-      if (mut_arg) {
-        // TODO: This is a waste if the argument is write only
-        tensor._set_conj(false);
-        at::conj_physical_(tensor);
-        mutable_inputs.emplace_back(tensor);
-      } else {
-        tensor = at::resolve_conj(tensor);
-      }
-      (*stack)[stack_start + i] = std::move(tensor);
-    } else if (ivalue.isTensorList()) {
-      auto tensors = std::move(ivalue).toTensorList();
-      if (mut_arg) {
-        for(const auto j : c10::irange(tensors.size())) {
-          Tensor t = tensors[j];
-          t._set_conj(false);
-          at::conj_physical_(t);
-          mutable_inputs.emplace_back(t);
-        }
-      } else {
-        for(const auto j : c10::irange(tensors.size())) {
-          tensors[j] = at::resolve_conj(tensors[j]);
-        }
-      }
-      (*stack)[stack_start + i] = std::move(tensors);
-    }
+  Tensor resolve_bit(const Tensor& tensor) override {
+    return at::resolve_conj(tensor);
   }
+  Tensor& math_op_(Tensor& tensor) override {
+    return at::conj_physical_(tensor);
+  }
+};

-
-  op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
-
-  for (auto& mutable_input : mutable_inputs) {
-    at::conj_physical_(mutable_input);
-    mutable_input._set_conj(true);
-  }
+void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  ConjFallback object;
+  object.fallback_impl(op, dispatch_keys, stack);
 }

 TORCH_LIBRARY_IMPL(_, Conjugate, m) {
@@ -142,7 +50,6 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
   m.impl("size.int", torch::CppFunction::makeFallthrough());
   m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
   m.impl("is_complex", torch::CppFunction::makeFallthrough());
-  m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
   m.impl("view_as_real", torch::CppFunction::makeFallthrough());
   m.impl("imag", torch::CppFunction::makeFallthrough());
   m.impl("real", torch::CppFunction::makeFallthrough());
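Note: the MathOpFallback base class lives in ATen/native/MathBitsFallback.h and is not part of this diff. Below is a minimal sketch, assuming an interface inferred only from the overrides ConjFallback provides above, plus an illustrative boxed-fallback registration; the real body of TORCH_LIBRARY_IMPL(_, Conjugate, m) is outside this hunk, so the registration shown is an assumption, not a quote from the file.

// Sketch only -- inferred from ConjFallback's overrides, not copied from
// ATen/native/MathBitsFallback.h.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

namespace at {

struct MathOpFallback {
  MathOpFallback(DispatchKey key_, std::string op_name_)
      : key(key_), op_name(std::move(op_name_)) {}
  // Queries whether the tensor carries the math bit (conjugate here).
  virtual bool is_bit_set(const Tensor& tensor) = 0;
  // Sets or clears the bit without materializing.
  virtual void _set_bit(const Tensor& tensor, bool value) = 0;
  // Returns a materialized tensor with the bit resolved.
  virtual Tensor resolve_bit(const Tensor& tensor) = 0;
  // In-place physical application of the math op (conj_physical_ here).
  virtual Tensor& math_op_(Tensor& tensor) = 0;
  // Shared boxed-fallback logic: materialize flagged inputs, redispatch below
  // `key`, and restore the bit on mutable arguments -- what the removed
  // conjugateFallback body used to do inline.
  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack);
  virtual ~MathOpFallback() = default;
  DispatchKey key;
  std::string op_name;
};

// Illustrative registration of the boxed fallback under the Conjugate key;
// the actual registration sits in the elided TORCH_LIBRARY_IMPL body.
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
}

} // namespace at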