Skip to content

Commit ef89e2b

Browse files
Add utility to get computed kernel in torch.library
ghstack-source-id: 439450f Pull Request resolved: #158393
1 parent 900fba4 commit ef89e2b

File tree

9 files changed

+374
-38
lines changed

9 files changed

+374
-38
lines changed

aten/src/ATen/core/boxing/KernelFunction.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <c10/core/DispatchKeySet.h>
77
#include <c10/util/TypeList.h>
88
#include <c10/util/intrusive_ptr.h>
9+
#include <atomic>
10+
#include <memory>
911
#include <type_traits>
1012

1113
namespace c10 {
@@ -17,6 +19,9 @@ class OperatorHandle;
1719
struct OperatorKernel;
1820
class KernelFunction;
1921

22+
class KernelToken;
23+
class SafeKernelFunction;
24+
2025
template <typename T>
2126
using has_symint = std::disjunction<
2227
std::is_same<c10::SymInt, T>,
@@ -90,6 +95,11 @@ class TORCH_API KernelFunction final {
9095
BoxedKernel::BoxedKernelFunction_withDispatchKeys;
9196

9297
KernelFunction();
98+
~KernelFunction();
99+
100+
101+
KernelFunction(const KernelFunction&) = default;
102+
KernelFunction& operator=(const KernelFunction&) = default;
93103

94104
// Fast path for dispatch to allow not touching the boxed kernel in
95105
// the common case where unboxed is available.
@@ -262,6 +272,16 @@ class TORCH_API KernelFunction final {
262272
// For testing internal invariants only
263273
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
264274

275+
// Register a token to be invalidated when this KernelFunction is destroyed
276+
void registerToken(std::weak_ptr<KernelToken> token) const;
277+
278+
// Invalidate all registered tokens
279+
void invalidateTokens();
280+
281+
// List of tokens that need to be invalidated when this KernelFunction is
282+
// destroyed
283+
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
284+
265285
private:
266286
explicit KernelFunction(
267287
std::unique_ptr<OperatorKernel> functor,
@@ -278,6 +298,32 @@ class TORCH_API KernelFunction final {
278298
void* sym_unboxed_kernel_func_;
279299
};
280300

301+
// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
302+
// destroyed
303+
class KernelToken {
304+
public:
305+
bool isValid() const;
306+
void invalidate();
307+
308+
private:
309+
std::atomic<bool> valid_{true};
310+
};
311+
312+
class SafeKernelFunction {
313+
public:
314+
SafeKernelFunction(const KernelFunction* kernel);
315+
316+
// Safe callBoxed - checks token validity first
317+
void callBoxed(
318+
const OperatorHandle& opHandle,
319+
DispatchKeySet dispatchKeySet,
320+
Stack* stack) const;
321+
322+
private:
323+
KernelFunction kernel_;
324+
std::shared_ptr<KernelToken> token_;
325+
};
326+
281327
} // namespace c10
282328

283329
#include <ATen/core/boxing/KernelFunction_impl.h>

aten/src/ATen/core/boxing/KernelFunction_impl.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ inline KernelFunction::KernelFunction()
2424
unboxed_kernel_func_(nullptr),
2525
sym_unboxed_kernel_func_(nullptr) {}
2626

27+
inline KernelFunction::~KernelFunction() {
28+
invalidateTokens();
29+
}
30+
2731
inline KernelFunction::KernelFunction(
2832
std::unique_ptr<OperatorKernel> functor,
2933
InternalBoxedKernelFunction* boxed_kernel_func,
@@ -157,6 +161,19 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
157161
std::forward<Args>(args)...);
158162
}
159163

164+
inline void KernelFunction::registerToken(
165+
std::weak_ptr<KernelToken> token) const {
166+
tokens_.push_back(std::move(token));
167+
}
168+
169+
inline void KernelFunction::invalidateTokens() {
170+
for (auto& weak_token : tokens_) {
171+
if (auto token = weak_token.lock()) {
172+
token->invalidate();
173+
}
174+
}
175+
}
176+
160177
inline KernelFunction KernelFunction::makeFromBoxedKernel(
161178
BoxedKernel boxed_fn) {
162179
return KernelFunction(
@@ -317,4 +334,33 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
317334
std::forward<Lambda>(lambda)));
318335
}
319336

337+
// KernelToken implementations
338+
inline bool KernelToken::isValid() const {
339+
return valid_.load(std::memory_order_acquire);
340+
}
341+
342+
inline void KernelToken::invalidate() {
343+
valid_.store(false, std::memory_order_release);
344+
}
345+
346+
// SafeKernelFunction implementations
347+
inline SafeKernelFunction::SafeKernelFunction(const KernelFunction* kernel)
348+
: kernel_(kernel ? *kernel : KernelFunction()),
349+
token_(std::make_shared<KernelToken>()) {
350+
// Register the token with the original kernel so it gets invalidated when the
351+
// kernel is destroyed
352+
if (kernel) {
353+
kernel->registerToken(token_);
354+
}
355+
}
356+
357+
inline void SafeKernelFunction::callBoxed(
358+
const OperatorHandle& opHandle,
359+
DispatchKeySet dispatchKeySet,
360+
Stack* stack) const {
361+
TORCH_CHECK(
362+
token_ && token_->isValid(), "SafeKernelFunction has been invalidated");
363+
kernel_.callBoxed(opHandle, dispatchKeySet, stack);
364+
}
365+
320366
} // namespace c10

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,10 @@ class TORCH_API OperatorHandle {
487487
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
488488
}
489489

490+
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
491+
return operatorDef_->op.getComputedKernelForDispatchKey(k);
492+
}
493+
490494
std::string dumpComputedTable() const {
491495
return operatorDef_->op.dumpComputedTable();
492496
}

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -247,23 +247,35 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
247247
void OperatorEntry::deregisterKernel_(
248248
const c10::Dispatcher& dispatcher,
249249
std::optional<DispatchKey> dispatch_key,
250-
AnnotatedKernelContainerIterator kernel
251-
) {
252-
// Redirect catchAll deregistrations to CompositeImplicitAutograd.
253-
DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
254-
auto found = kernels_.find(dk);
255-
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
256-
auto& k = found->second;
250+
AnnotatedKernelContainerIterator kernel) {
251+
// Redirect catchAll deregistrations to CompositeImplicitAutograd.
252+
DispatchKey dk = dispatch_key.has_value()
253+
? *dispatch_key
254+
: DispatchKey::CompositeImplicitAutograd;
255+
auto found = kernels_.find(dk);
256+
TORCH_INTERNAL_ASSERT(
257+
found != kernels_.end(),
258+
"Tried to deregister a kernel for dispatch key ",
259+
toString(dispatch_key),
260+
" but there are no kernels registered for this dispatch key. The operator is ",
261+
toString(name_));
262+
auto& k = found->second;
257263
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
258-
// We are about to remove the array from the map, no need to do anything.
264+
// We are about to remove the array from the map, no need to do anything.
259265
#else
260-
k.erase(kernel);
266+
k.erase(kernel);
261267
#endif
262-
if (k.empty()) {
263-
// the invariant says we don't want empty lists but instead remove the list from the map
264-
kernels_.erase(found);
265-
}
266-
updateDispatchTable_(dispatcher, dk);
268+
if (k.empty()) {
269+
// the invariant says we don't want empty lists but instead remove the list
270+
// from the map
271+
kernels_.erase(found);
272+
}
273+
// The KernelFunction object in dispatchTable
274+
// (1) s a copy of the object in kernels_
275+
// (2) is the one that gets returned by getComputedKernelForDispatchKey
276+
// so we cannot call invalidate tokens on kernel here and need to invalidate
277+
// the tokens on that one instead.
278+
updateDispatchTable_(dispatcher, dk, /*invalidate_tokens=*/true);
267279
}
268280

269281
void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
@@ -305,6 +317,17 @@ bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
305317
return dispatchTable_[dispatch_ix].isValid();
306318
}
307319

320+
SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
321+
DispatchKey k) const {
322+
TORCH_CHECK(
323+
!isAliasDispatchKey(k),
324+
"Alias keys do not have runtime kernel registrations.");
325+
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
326+
TORCH_CHECK(dispatchTable_[dispatch_ix].isValid())
327+
328+
return SafeKernelFunction(&dispatchTable_[dispatch_ix]);
329+
}
330+
308331
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
309332
auto kern_it = kernels_.find(dispatch_key);
310333
if (kern_it != kernels_.end()) {
@@ -450,46 +473,63 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
450473
// dispatch keys (e.g. runtime keys and their associated autograd keys,
451474
// or alias keys and their associated keysets).
452475
// This function should be considered a private helper for updateDispatchTable_()
453-
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
476+
void OperatorEntry::updateDispatchTableEntry_(
477+
const c10::Dispatcher& dispatcher,
478+
DispatchKey dispatch_key,
479+
bool invalidateTokens) {
454480
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
455481
if (C10_UNLIKELY(dispatch_ix == -1)) {
456482
return;
457483
}
458-
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
459-
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
484+
// Invalidate tokens for the old dispatch table entry before replacing it
485+
if (invalidateTokens) {
486+
dispatchTable_[dispatch_ix].invalidateTokens();
487+
}
488+
dispatchTable_[dispatch_ix] =
489+
computeDispatchTableEntry(dispatcher, dispatch_key);
490+
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(
491+
dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
460492
}
461493

462494
// synchronizes the dispatch table entries for a given dispatch key *and its
463495
// associated keys* with the current state of kernel registrations in the
464496
// dispatcher.
465497
// After a kernel has been registered to a dispatch key, a call to this
466498
// function will synchronize the dispatcher state. See e.g. registerKernel()
467-
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
468-
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
469-
// See Note [Undefined in dispatchTable_]
499+
void OperatorEntry::updateDispatchTable_(
500+
const c10::Dispatcher& dispatcher,
501+
DispatchKey dispatch_key,
502+
bool invalidate_tokens) {
503+
// Handle Undefined separately since it isn't a runtime key but we have an
504+
// entry in dispatchTable_. See Note [Undefined in dispatchTable_]
470505
if (dispatch_key == DispatchKey::Undefined) {
471-
updateDispatchTableEntry_(dispatcher, dispatch_key);
506+
updateDispatchTableEntry_(dispatcher, dispatch_key, invalidate_tokens);
472507
return;
473508
}
474509
for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
475-
updateDispatchTableEntry_(dispatcher, k);
476-
}
477-
// Registration to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
478-
// We cannot do this above since Undefined cannot be represented in DispatchKeySet.
479-
if (dispatch_key == DispatchKey::CompositeImplicitAutograd
480-
|| dispatch_key == DispatchKey::CompositeExplicitAutograd
481-
|| dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
482-
updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
510+
updateDispatchTableEntry_(dispatcher, k, invalidate_tokens);
511+
}
512+
// Registration to CompositeExplicitAutogradNonFunctional,
513+
// CompositeExplicitAutograd and CompositeImplicitAutograd should be populated
514+
// to Undefined. We cannot do this above since Undefined cannot be represented
515+
// in DispatchKeySet.
516+
if (dispatch_key == DispatchKey::CompositeImplicitAutograd ||
517+
dispatch_key == DispatchKey::CompositeExplicitAutograd ||
518+
dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
519+
updateDispatchTableEntry_(
520+
dispatcher, DispatchKey::Undefined, invalidate_tokens);
483521
}
484522
// Note [Refresh Runtime Autograd entries in dispatchTable_]
485-
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
486-
// In theory, we should only have to check if the given runtime key has "dense" functionality,
487-
// e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
488-
// However, there are some backends that should be included in this set that don't have the dense key set.
489-
// E.g. DispatchKey::Meta, DispatchKey::MAIA.
523+
// Registering to backend key might affect computed entry at its Autograd
524+
// backend key due to (2.1) & (2.3). In theory, we should only have to check
525+
// if the given runtime key has "dense" functionality, e.g. DispatchKey::CPU
526+
// (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
527+
// However, there are some backends that should be included in this set that
528+
// don't have the dense key set. E.g. DispatchKey::Meta, DispatchKey::MAIA.
490529
if (c10::isBackendDispatchKey(dispatch_key)) {
491-
DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
492-
updateDispatchTableEntry_(dispatcher, autograd_key);
530+
DispatchKey autograd_key =
531+
getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
532+
updateDispatchTableEntry_(dispatcher, autograd_key, invalidate_tokens);
493533
}
494534
}
495535

aten/src/ATen/core/dispatch/OperatorEntry.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ class TORCH_API OperatorEntry final {
217217
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
218218
// Returns true if the "computed table" has an entry for a particular key.
219219
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
220+
// Returns a KernelFunction corresponding to the kernel in dispatchTable
221+
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
220222
// Returns all the operator tags added at the time of registration
221223
const std::vector<at::Tag>& getTags() const;
222224
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
@@ -318,11 +320,13 @@ class TORCH_API OperatorEntry final {
318320
// dispatch key.
319321
void updateDispatchTableEntry_(
320322
const c10::Dispatcher& dispatcher,
321-
DispatchKey dispatch_key);
323+
DispatchKey dispatch_key,
324+
bool invalidateTokens = false);
322325
// Like above, but also handles alias dispatch keys.
323326
void updateDispatchTable_(
324327
const c10::Dispatcher& dispatcher,
325-
DispatchKey dispatch_key);
328+
DispatchKey dispatch_key,
329+
bool invalidateTokens = false);
326330
// Like above, but for ALL entries in the dispatch table.
327331
void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
328332
// Retrieves a pointer to AnnotatedKernel at

0 commit comments

Comments
 (0)