diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 06bcc5d4f49b..62071b97452e 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -6,6 +6,8 @@
 #include
 #include
 #include
+#include <atomic>
+#include <memory>
 #include

 namespace c10 {
@@ -17,6 +19,9 @@ class OperatorHandle;
 struct OperatorKernel;
 class KernelFunction;

+class KernelToken;
+class SafeKernelFunction;
+
 template <typename T>
 using has_symint = std::disjunction<
     std::is_same<c10::SymInt, T>,
@@ -90,6 +95,16 @@ class TORCH_API KernelFunction final {
       BoxedKernel::BoxedKernelFunction_withDispatchKeys;

   KernelFunction();
+  ~KernelFunction();
+
+  // Copies intentionally do not inherit tokens_: only destruction of the
+  // registered kernel (the entry in the dispatcher's kernels_ list) should
+  // invalidate outstanding SafeKernelFunction tokens.
+  KernelFunction(const KernelFunction&);
+  KernelFunction& operator=(const KernelFunction&);
+
+  KernelFunction(KernelFunction&&) noexcept = default;
+  KernelFunction& operator=(KernelFunction&&) noexcept = default;

   // Fast path for dispatch to allow not touching the boxed kernel in
   // the common case where unboxed is available.
@@ -262,6 +277,13 @@ class TORCH_API KernelFunction final {
   // For testing internal invariants only
   bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

+  // Register a token to be invalidated when this KernelFunction is destroyed
+  void registerToken(std::weak_ptr<KernelToken> token) const;
+
+  // List of tokens that need to be invalidated when this KernelFunction is
+  // destroyed
+  mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
+
  private:
   explicit KernelFunction(
       std::unique_ptr<OperatorKernel> functor,
@@ -278,6 +300,47 @@ class TORCH_API KernelFunction final {
   void* sym_unboxed_kernel_func_;
 };

+// Token held by a SafeKernelFunction; it is invalidated when the underlying
+// KernelFunction is destroyed
+class KernelToken {
+ public:
+  bool isValid() const;
+  void invalidate();
+
+ private:
+  std::atomic<bool> invalid_{false};
+};
+
+class SafeKernelFunction {
+ public:
+  SafeKernelFunction(
+      const KernelFunction* kernel,
+      std::string debug,
+      std::shared_ptr<OperatorHandle> opHandle);
+
+  // Safe callBoxed - checks token validity first
+  void callBoxed(
+      const OperatorHandle& opHandle,
+      DispatchKeySet dispatchKeySet,
+      Stack* stack) const;
+
+  // Get debug information
+  const std::string& debug() const {
+    return debug_;
+  }
+
+  // Get the OperatorHandle that lives on this SafeKernelFunction
+  const OperatorHandle& opHandle() const {
+    return *opHandle_;
+  }
+
+ private:
+  KernelFunction kernel_;
+  std::shared_ptr<KernelToken> token_;
+  std::string debug_;
+  std::shared_ptr<OperatorHandle> opHandle_;
+};
+
 } // namespace c10

 #include <ATen/core/boxing/KernelFunction_impl.h>
diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h
index df49d6227ee9..dc31ac7a6c34 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_impl.h
+++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h
@@ -24,6 +24,26 @@ inline KernelFunction::KernelFunction()
       unboxed_kernel_func_(nullptr),
       sym_unboxed_kernel_func_(nullptr) {}

+inline KernelFunction::~KernelFunction() {
+  for (auto& weak_token : tokens_) {
+    if (auto token = weak_token.lock()) {
+      token->invalidate();
+    }
+  }
+}
+
+inline KernelFunction::KernelFunction(const KernelFunction& other)
+    : boxed_kernel_func_(other.boxed_kernel_func_),
+      unboxed_kernel_func_(other.unboxed_kernel_func_),
+      sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) {}
+
+inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) {
+  boxed_kernel_func_ = other.boxed_kernel_func_;
+  unboxed_kernel_func_ = other.unboxed_kernel_func_;
+  sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_;
+  return *this;
+}
+
 inline KernelFunction::KernelFunction(
     std::unique_ptr<OperatorKernel> functor,
     InternalBoxedKernelFunction* boxed_kernel_func,
@@ -157,6 +177,11 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
       std::forward<Args>(args)...);
 }

+inline void KernelFunction::registerToken(
+    std::weak_ptr<KernelToken> token) const {
+  tokens_.push_back(std::move(token));
+}
+
 inline KernelFunction KernelFunction::makeFromBoxedKernel(
     BoxedKernel boxed_fn) {
   return KernelFunction(
@@ -317,4 +342,38 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
       std::forward<Lambda>(lambda)));
 }

+inline bool KernelToken::isValid() const {
+  return !invalid_.load(std::memory_order_acquire);
+}
+
+inline void KernelToken::invalidate() {
+  invalid_.store(true, std::memory_order_release);
+}
+
+inline SafeKernelFunction::SafeKernelFunction(
+    const KernelFunction* kernel,
+    std::string debug,
+    std::shared_ptr<OperatorHandle> opHandle)
+    : kernel_(kernel ? *kernel : KernelFunction()),
+      token_(std::make_shared<KernelToken>()),
+      debug_(std::move(debug)),
+      opHandle_(std::move(opHandle)) {
+  // Register the token with the original kernel so it gets invalidated when
+  // that kernel is deregistered
+  if (kernel) {
+    kernel->registerToken(token_);
+  }
+}
+
+inline void SafeKernelFunction::callBoxed(
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    Stack* stack) const {
+  TORCH_CHECK(
+      token_ && token_->isValid(),
+      "SafeKernelFunction has been invalidated ",
+      debug_);
+  kernel_.callBoxed(opHandle, dispatchKeySet, stack);
+}
+
 } // namespace c10
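A note on the lifetime protocol established by the two headers above: each `SafeKernelFunction` owns its `KernelToken` through a `shared_ptr`, the registered `KernelFunction` only observes tokens through `weak_ptr`s, and the kernel's destructor flips the atomic flag on every token still alive. The standalone Python sketch below mirrors that ownership pattern (illustrative only; these classes are not PyTorch APIs):

```python
import weakref


class KernelToken:
    # Mirror of c10::KernelToken: a one-way validity flag.
    def __init__(self):
        self.invalid = False

    def invalidate(self):
        self.invalid = True


class Kernel:
    # Mirror of c10::KernelFunction: observes tokens weakly and
    # invalidates the survivors when it is destroyed.
    def __init__(self):
        self._tokens = []  # like std::vector<std::weak_ptr<KernelToken>>

    def register_token(self, token):
        self._tokens.append(weakref.ref(token))

    def __del__(self):
        for ref in self._tokens:
            token = ref()  # like weak_ptr::lock()
            if token is not None:
                token.invalidate()


kernel = Kernel()
token = KernelToken()  # in C++ this is owned by a SafeKernelFunction
kernel.register_token(token)
assert not token.invalid
del kernel  # "deregistration": the token flips exactly once
assert token.invalid
```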
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index bc043df6a93e..43eb0028c70f 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -487,6 +487,10 @@ class TORCH_API OperatorHandle {
     return operatorDef_->op.hasComputedKernelForDispatchKey(k);
   }

+  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.getComputedKernelForDispatchKey(k);
+  }
+
   std::string dumpComputedTable() const {
     return operatorDef_->op.dumpComputedTable();
   }
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index b4063fb720be..c172e9b9c609 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -315,6 +315,42 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
   return nullptr;
 }

+SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
+    DispatchKey k) const {
+  TORCH_CHECK(
+      !isAliasDispatchKey(k),
+      "Alias keys do not have runtime kernel registrations.");
+  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
+  TORCH_CHECK(
+      dispatchTable_[dispatch_ix].isValid(),
+      "no kernel for ",
+      k,
+      " for ",
+      name_);
+
+  // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction.
+  //
+  // The KernelFunction in dispatchTable_ is a copy of the KernelFunction in
+  // the AnnotatedKernel in kernels_. A KernelFunction is only truly
+  // deregistered when the kernel is removed from kernels_; the copy in
+  // dispatchTable_ may be replaced earlier (whenever a newer kernel is
+  // registered). We therefore back the returned SafeKernelFunction by the
+  // original KernelFunction in kernels_, so that it is invalidated only when
+  // the kernel is actually deregistered.
+  auto [annotatedKernel, _] =
+      computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
+
+  // Use findSchemaOrThrow to get an OperatorHandle for this OperatorEntry
+  auto& dispatcher = c10::Dispatcher::singleton();
+  auto opHandle = dispatcher.findSchemaOrThrow(
+      name_.name.c_str(), name_.overload_name.c_str());
+
+  return SafeKernelFunction(
+      &annotatedKernel.kernel,
+      annotatedKernel.debug,
+      std::make_shared<OperatorHandle>(opHandle));
+}
+
 const std::vector<at::Tag>& OperatorEntry::getTags() const {
 #if defined C10_MOBILE
   TORCH_CHECK(false, "tags are not saved for Mobile");
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index 83200ff9c94f..59b54ce1d9d3 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -217,6 +217,8 @@ class TORCH_API OperatorEntry final {
   const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
   // Returns true if the "computed table" has an entry for a particular key.
   bool hasComputedKernelForDispatchKey(DispatchKey k) const;
+  // Returns a SafeKernelFunction corresponding to the kernel in the computed dispatch table
+  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;
   void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
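The comment in `getComputedKernelForDispatchKey` is the crux of the design: handles are tied to the registration in `kernels_`, not to the copy in `dispatchTable_`, so a handle survives being shadowed by a newer registration and dies only when its own registration is removed. A sketch of the observable behavior, using the Python API added later in this patch (`aten::arange.start` chosen arbitrarily):

```python
import torch

# Handle backed by the built-in CPU registration in kernels_,
# not by the dispatchTable_ copy.
kernel = torch.library.get_kernel("aten::arange.start", "CPU")
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)

with torch.library._scoped_library("aten", "IMPL") as lib:
    # The override replaces the dispatchTable_ entry, but the built-in
    # registration stays alive in kernels_, so the handle still works.
    lib.impl(
        "arange.start",
        lambda keys, start, end, *rest: torch.zeros(int(end - start)),
        "CPU",
        with_keyset=True,
    )
    print(kernel.call_boxed(keyset, 0, 3))  # tensor([0., 1., 2.])

# Only removing the built-in registration itself would invalidate `kernel`.
```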
diff --git a/docs/source/library.md b/docs/source/library.md
index 9d706e2e1080..b31ca95d5b6a 100644
--- a/docs/source/library.md
+++ b/docs/source/library.md
@@ -56,6 +56,7 @@ via PyTorch's C++ operator registration APIs).
 .. autofunction:: infer_schema
 .. autoclass:: torch._library.custom_ops.CustomOpDef
    :members: set_kernel_enabled
+.. autofunction:: get_kernel
 ```

 ## Low-level APIs
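One calling-convention detail the tests below lean on: registering with `with_keyset=True` makes the dispatcher prepend the current `DispatchKeySet` to the impl's arguments, and that same keyset must be forwarded as the first argument of `call_boxed` so the original kernel runs under the same dispatch state. A minimal sketch of the convention (`aten::relu` chosen arbitrarily):

```python
import torch

original = torch.library.get_kernel("aten::relu", "CPU")


def my_relu_impl(keyset, x):
    # `keyset` is prepended by the dispatcher because of with_keyset=True;
    # forward it unchanged to the captured original kernel.
    return original.call_boxed(keyset, x)


lib = torch.library.Library("aten", "IMPL")
lib.impl("relu", my_relu_impl, "CPU", with_keyset=True)
print(torch.relu(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])
```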
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 5a494f548742..491648494f6f 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -11,6 +11,7 @@ import tempfile
 import typing
 import unittest

+from functools import partial
 from pathlib import Path
 from typing import *  # noqa: F403
@@ -4156,6 +4157,146 @@ def test_any_output_is_alias_to_input_or_output(self):
             )
         )

+    def test_library_get_kernel(self):
+        """Register a custom kernel, call it through get_kernel, then
+        deregister it and verify that the handle errors."""
+
+        # Register a dummy kernel for arange to the CPU key that returns a
+        # tensor of ones
+        def dummy_arange_cpu(
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            size = max(0, int(end - start))
+            return torch.ones(size, dtype=dtype, device=device)
+
+        with torch.library._scoped_library("aten", "IMPL") as lib:
+            lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True)
+
+            kernel = torch.library.get_kernel("aten::arange.start", "CPU")
+            dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
+            result = kernel.call_boxed(dispatch_keys, 0, 5)
+            self.assertEqual(result, torch.ones(5))
+
+        # The kernel is invalidated after exiting the scoped_library context
+        with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
+            kernel.call_boxed(dispatch_keys, 0, 5)
+
+    def test_library_get_kernel_with_conditional_dispatch(self):
+        """Test stacked overrides that conditionally call into the previously
+        computed kernel."""
+
+        def conditional_arange_cpu1(
+            original_kernel,
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            # If end is even, use the original kernel; otherwise return a
+            # ones tensor
+            if end % 2 == 0:
+                return original_kernel.call_boxed(
+                    dispatch_keys,
+                    start,
+                    end,
+                    dtype=dtype,
+                    layout=layout,
+                    device=device,
+                    pin_memory=pin_memory,
+                )
+            else:
+                size = max(0, int(end - start))
+                return torch.ones(size, dtype=dtype, device=device)
+
+        def conditional_arange_cpu2(
+            original_kernel,
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            # If start is even, use the original kernel; otherwise return a
+            # twos tensor
+            if start % 2 == 0:
+                return original_kernel.call_boxed(
+                    dispatch_keys,
+                    start,
+                    end,
+                    dtype=dtype,
+                    layout=layout,
+                    device=device,
+                    pin_memory=pin_memory,
+                )
+            else:
+                size = max(0, int(end - start))
+                return torch.empty(size, dtype=dtype, device=device).fill_(2)
+
+        original_kernel = torch.library.get_kernel("aten::arange.start", "CPU")
+        expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6)
+        expected_result3, expected_result4, expected_result5 = (
+            torch.ones(5),
+            torch.arange(0, 6),
+            torch.ones(5).fill_(2),
+        )
+
+        with torch.library._scoped_library("aten", "IMPL") as lib2:
+            with torch.library._scoped_library("aten", "IMPL") as lib1:
+                lib1.impl(
+                    "arange.start",
+                    partial(conditional_arange_cpu1, original_kernel),
+                    "CPU",
+                    with_keyset=True,
+                )
+
+                self.assertEqual(torch.arange(0, 5), expected_result1)
+                self.assertEqual(torch.arange(0, 6), expected_result2)
+
+                new_original_kernel = torch.library.get_kernel(
+                    "aten::arange.start", "CPU"
+                )
+                lib2.impl(
+                    "arange.start",
+                    partial(conditional_arange_cpu2, new_original_kernel),
+                    "CPU",
+                    allow_override=True,
+                    with_keyset=True,
+                )
+
+                self.assertEqual(torch.arange(0, 5), expected_result3)
+                self.assertEqual(torch.arange(0, 6), expected_result4)
+                self.assertEqual(torch.arange(1, 6), expected_result5)
+
+            # new_original_kernel is invalidated once lib1 is destroyed, so
+            # the path through it now raises
+            with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
+                torch.arange(0, 5)
+
+            # Paths that never reach the invalidated kernel still work
+            self.assertEqual(torch.arange(1, 6), expected_result5)
+
+    def test_library_get_kernel_invalid(self):
+        """Test that get_kernel raises an error when no kernel is available."""
+        with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib:
+            lib.define("cpu_only_op(Tensor x) -> Tensor")
+            lib.impl("cpu_only_op", lambda x: x * 2, "CPU")
+
+            cpu_kernel = torch.library.get_kernel(
+                "test_invalid_kernel::cpu_only_op", "CPU"
+            )
+            self.assertIsNotNone(cpu_kernel)
+
+            # CUDA fails the isValid() check since no CUDA kernel exists
+            with self.assertRaisesRegex(
+                RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op"
+            ):
+                torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA")
+

 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
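As `test_library_get_kernel_invalid` above shows, `get_kernel` raises as soon as the computed table has no valid entry for the key, so probing for a backend-specific kernel is a try/except affair. A small sketch of that pattern (the `demo` namespace and `try_get_kernel` helper are hypothetical):

```python
import torch


def try_get_kernel(qualname, key):
    # get_kernel raises RuntimeError ("no kernel for <key> for <op>") when
    # the computed dispatch table has no valid entry for `key`.
    try:
        return torch.library.get_kernel(qualname, key)
    except RuntimeError:
        return None


lib = torch.library.Library("demo", "DEF")
lib.define("cpu_only(Tensor x) -> Tensor")
lib.impl("cpu_only", lambda x: x * 2, "CPU")

assert try_get_kernel("demo::cpu_only", "CPU") is not None
assert try_get_kernel("demo::cpu_only", "CUDA") is None
```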
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 9e03c7dba830..8e97d7e3e3ab 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1693,6 +1693,11 @@ class _DispatchModule:
 _after_ADInplaceOrView_keyset: DispatchKeySet
 _after_autograd_keyset: DispatchKeySet

+class _SafeKernelFunction:
+    def call_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...
+    @property
+    def op_handle(self) -> _DispatchOperatorHandle: ...
+
 def _dispatch_library(
     kind: str,
     name: str,
@@ -1730,6 +1735,10 @@ def _dispatch_has_computed_kernel_for_dispatch_key(
     name: str,
     dispatch: _dispatchkey,
 ) -> _bool: ...
+def _dispatch_get_computed_kernel_for_dispatch_key(
+    name: str,
+    dispatch: _dispatchkey,
+) -> _SafeKernelFunction: ...
 def _dispatch_find_dangling_impls() -> list[str]: ...
 def _dispatch_get_all_op_names() -> list[str]: ...
 def _dispatch_tls_set_dispatch_key_excluded(
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 019ce2070634..568d9402140d 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -602,6 +602,43 @@ void initDispatchBindings(PyObject* module) {
         c10::parseDispatchKey(dispatch));
   });

+  // Bind the SafeKernelFunction class
+  py::class_<c10::SafeKernelFunction>(m, "_SafeKernelFunction")
+      .def(
+          "call_boxed",
+          [](const c10::SafeKernelFunction& self,
+             c10::DispatchKeySet keyset,
+             py::args args,
+             const py::kwargs& kwargs) {
+            const auto& op = self.opHandle();
+            auto stack = torch::jit::createStackForSchema(
+                op.schema(),
+                std::move(args),
+                kwargs,
+                /*self=*/std::nullopt);
+            self.callBoxed(op, keyset, &stack);
+            return torch::jit::createPyObjectForStack(std::move(stack));
+          })
+      .def(
+          "__repr__",
+          [](const c10::SafeKernelFunction& self) {
+            return "SafeKernelFunction(debug='" + self.debug() + "')";
+          })
+      .def_property_readonly(
+          "op_handle", [](const c10::SafeKernelFunction& self) -> py::object {
+            return py::cast(self.opHandle());
+          });
+
+  m.def(
+      "_dispatch_get_computed_kernel_for_dispatch_key",
+      [](const char* name,
+         c10::DispatchKey dispatch) -> c10::SafeKernelFunction {
+        auto op =
+            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+        TORCH_CHECK(op, "operator ", name, " does not exist");
+        return op->getComputedKernelForDispatchKey(dispatch);
+      });
+
   m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
     auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
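The `call_boxed` binding above packs the Python arguments into an IValue stack against the operator's schema (`createStackForSchema`), runs the kernel, and converts the resulting stack back into Python objects. In practice this means `call_boxed` accepts the operator's usual positional and keyword arguments after the leading keyset, for example:

```python
import torch

kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)

a, b = torch.ones(3), torch.full((3,), 2.0)
# Arguments after the keyset follow aten::add.Tensor's schema, including
# schema keyword arguments such as alpha.
out = kernel.call_boxed(keyset, a, b, alpha=2.0)
print(out)  # tensor([5., 5., 5.])
```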
diff --git a/torch/library.py b/torch/library.py
index f24c3fbd4276..bbdaebea9521 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -45,6 +45,7 @@
     "register_torch_dispatch",
     "register_vmap",
     "get_ctx",
+    "get_kernel",
     "custom_op",
     "triton_op",
     "wrap_triton",
@@ -1475,6 +1476,87 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
     return torch._library.fake_impl.global_ctx_getter()


+def get_kernel(
+    op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey]
+) -> torch._C._SafeKernelFunction:
+    """Returns the computed kernel for a given operator and dispatch key.
+
+    This function retrieves the kernel that would be executed for a given
+    operator and dispatch key combination. The returned SafeKernelFunction
+    can be used to call the kernel in a boxed fashion. The intended use case
+    is to retrieve the original kernel for a dispatch key and then register
+    another kernel to the same key that calls into the original kernel for
+    certain cases.
+
+    Args:
+        op (str | OpOverload | CustomOpDef): The operator (including the
+            overload) to look up. Can be a string (e.g., "aten::add.Tensor"),
+            an OpOverload, or a CustomOpDef.
+        dispatch_key (str | torch.DispatchKey): The dispatch key to get the
+            kernel for. Can be a string (e.g., "CPU", "CUDA") or a DispatchKey
+            enum value.
+
+    Returns:
+        torch._C._SafeKernelFunction: A safe kernel function that can be used
+        to call the kernel.
+
+    Raises:
+        RuntimeError: If the operator does not exist or has no kernel for the
+            given dispatch key.
+        ValueError: If ``op`` or ``dispatch_key`` has an unexpected type or
+            value.
+
+    Example:
+        >>> # Get the CPU kernel for torch.add
+        >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
+        >>>
+        >>> # You can also use the DispatchKey enum
+        >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
+        >>>
+        >>> # Or use an OpOverload directly
+        >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
+        >>>
+        >>> # Example: using get_kernel in a custom op with conditional dispatch.
+        >>> # Get the original kernel for torch.sin
+        >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
+        >>>
+        >>> # If the input has negative values, use the original sin; otherwise
+        >>> # return zeros
+        >>> def conditional_sin_impl(dispatch_keys, x):
+        >>>     if (x < 0).any():
+        >>>         return original_sin_kernel.call_boxed(dispatch_keys, x)
+        >>>     else:
+        >>>         return torch.zeros_like(x)
+        >>>
+        >>> lib = torch.library.Library("aten", "IMPL")
+        >>> # with_keyset=True makes the current DispatchKeySet the first argument
+        >>> # to the impl; it must also be the first argument to ``kernel.call_boxed``
+        >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
+        >>>
+        >>> # Test the conditional behavior
+        >>> x_positive = torch.tensor([1.0, 2.0])
+        >>> x_mixed = torch.tensor([-1.0, 2.0])
+        >>> torch.sin(x_positive)
+        tensor([0., 0.])
+        >>> torch.sin(x_mixed)
+        tensor([-0.8415, 0.9093])
+    """
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
+        raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}")
+
+    if isinstance(op, torch._library.custom_ops.CustomOpDef):
+        op = op._qualname
+    elif isinstance(op, torch._ops.OpOverload):
+        op = op._name
+
+    if isinstance(dispatch_key, str):
+        try:
+            dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
+        except KeyError:
+            raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None
+
+    return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key)
+
+
 _OPCHECK_DEFAULT_UTILS = (
     "test_schema",
     "test_autograd_registration",
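Putting it all together, the workflow the docstring describes is: capture the original kernel, install a wrapper that delegates to it, and let the scoped library remove the override on exit. A sketch (`logging_clamp` is a hypothetical wrapper around `aten::clamp`):

```python
import torch
from functools import partial

original = torch.library.get_kernel("aten::clamp", "CPU")


def logging_clamp(orig, keyset, x, min=None, max=None):
    # Observe the call, then delegate to the captured original kernel.
    print(f"clamp called with min={min}, max={max}")
    return orig.call_boxed(keyset, x, min=min, max=max)


with torch.library._scoped_library("aten", "IMPL") as lib:
    lib.impl("clamp", partial(logging_clamp, original), "CPU", with_keyset=True)
    torch.clamp(torch.tensor([-1.0, 2.0]), min=0.0)  # logs, then clamps

# Exiting the scope removes the override, and `original` remains valid
# because the built-in clamp kernel was never deregistered.
```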