-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Add utility to get computed kernel in torch.library #158393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/mikaylagawarecki/320/base
Are you sure you want to change the base?
Changes from all commits
fcdcdaf
0407e30
848ef21
51a262d
4df1e1d
1b3c50b
2b28bef
e3608bf
1797345
06b0657
ca3da8b
b224df3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
#include <c10/core/DispatchKeySet.h> | ||
#include <c10/util/TypeList.h> | ||
#include <c10/util/intrusive_ptr.h> | ||
#include <atomic> | ||
#include <memory> | ||
#include <type_traits> | ||
|
||
namespace c10 { | ||
|
@@ -17,6 +19,9 @@ class OperatorHandle; | |
struct OperatorKernel; | ||
class KernelFunction; | ||
|
||
class KernelToken; | ||
class SafeKernelFunction; | ||
|
||
template <typename T> | ||
using has_symint = std::disjunction< | ||
std::is_same<c10::SymInt, T>, | ||
|
@@ -90,6 +95,12 @@ class TORCH_API KernelFunction final { | |
BoxedKernel::BoxedKernelFunction_withDispatchKeys; | ||
|
||
KernelFunction(); | ||
~KernelFunction(); | ||
|
||
KernelFunction(const KernelFunction&) = default; | ||
KernelFunction& operator=(const KernelFunction&) = default; | ||
|
||
KernelFunction(KernelFunction&&) noexcept = default; | ||
|
||
// Fast path for dispatch to allow not touching the boxed kernel in | ||
// the common case where unboxed is available. | ||
|
@@ -262,6 +273,13 @@ class TORCH_API KernelFunction final { | |
// For testing internal invariants only | ||
bool _equalsBoxedAndUnboxed(const KernelFunction&) const; | ||
|
||
// Register a token to be invalidated when this KernelFunction is destroyed | ||
void registerToken(std::weak_ptr<KernelToken> token) const; | ||
|
||
// List of tokens that need to be invalidated when this KernelFunction is | ||
// destroyed | ||
mutable std::vector<std::weak_ptr<KernelToken>> tokens_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why mutable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also why weak_ptr and not shared_ptr? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this is necessary in order to make
What's the benefit of |
||
|
||
private: | ||
explicit KernelFunction( | ||
std::unique_ptr<OperatorKernel> functor, | ||
|
@@ -278,6 +296,47 @@ class TORCH_API KernelFunction final { | |
void* sym_unboxed_kernel_func_; | ||
}; | ||
|
||
// Token held by SafeKernelFunction that gets invalidated when KernelFunction is | ||
// destroyed | ||
class KernelToken { | ||
public: | ||
bool isValid() const; | ||
void invalidate(); | ||
|
||
private: | ||
std::atomic<bool> invalid_{false}; | ||
}; | ||
|
||
class SafeKernelFunction { | ||
public: | ||
SafeKernelFunction( | ||
const KernelFunction* kernel, | ||
std::string debug, | ||
std::shared_ptr<OperatorHandle> opHandle); | ||
|
||
// Safe callBoxed - checks token validity first | ||
void callBoxed( | ||
const OperatorHandle& opHandle, | ||
DispatchKeySet dispatchKeySet, | ||
Stack* stack) const; | ||
|
||
// Get debug information | ||
const std::string& debug() const { | ||
return debug_; | ||
} | ||
|
||
// Get the OpHandle that lives on this SafeKernelFunction | ||
const OperatorHandle& opHandle() const { | ||
return *opHandle_; | ||
} | ||
|
||
private: | ||
KernelFunction kernel_; | ||
std::shared_ptr<KernelToken> token_; | ||
std::string debug_; | ||
std::shared_ptr<OperatorHandle> opHandle_; | ||
}; | ||
|
||
} // namespace c10 | ||
|
||
#include <ATen/core/boxing/KernelFunction_impl.h> |
Uh oh!
There was an error while loading. Please reload this page.