Skip to content

Add utility to get computed kernel in torch.library #158393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: gh/mikaylagawarecki/320/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions aten/src/ATen/core/boxing/KernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <atomic>
#include <memory>
#include <type_traits>

namespace c10 {
Expand All @@ -17,6 +19,9 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;

class KernelToken;
class SafeKernelFunction;

template <typename T>
using has_symint = std::disjunction<
std::is_same<c10::SymInt, T>,
Expand Down Expand Up @@ -90,6 +95,12 @@ class TORCH_API KernelFunction final {
BoxedKernel::BoxedKernelFunction_withDispatchKeys;

KernelFunction();
~KernelFunction();

KernelFunction(const KernelFunction&) = default;
KernelFunction& operator=(const KernelFunction&) = default;

KernelFunction(KernelFunction&&) noexcept = default;

// Fast path for dispatch to allow not touching the boxed kernel in
// the common case where unboxed is available.
Expand Down Expand Up @@ -262,6 +273,13 @@ class TORCH_API KernelFunction final {
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

// Register a token to be invalidated when this KernelFunction is destroyed
void registerToken(std::weak_ptr<KernelToken> token) const;

// List of tokens that need to be invalidated when this KernelFunction is
// destroyed
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why mutable?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why weak_ptr and not shared_ptr?

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why mutable

I think this is necessary in order to make registerToken const, which was in turn needed to allow SafeKernelFunction to take in const KernelFunction*, removing this would necessitate const_cast-ing the annotatedKernel.kernel in getComputedKernelForDispatchKey, wdyt

why weak_ptr and not shared_ptr

What's the benefit of shared_ptr over weak_ptr here? if we use weak_ptr, the KernelToken dies with the SafeKernelFunction, which I think is what we want to achieve here (?)


private:
explicit KernelFunction(
std::unique_ptr<OperatorKernel> functor,
Expand All @@ -278,6 +296,47 @@ class TORCH_API KernelFunction final {
void* sym_unboxed_kernel_func_;
};

// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
// destroyed. The SafeKernelFunction owns the token via shared_ptr and checks it
// before every call; the KernelFunction side holds only weak_ptrs, so the
// token's lifetime is tied to the SafeKernelFunction, not the kernel.
class KernelToken {
 public:
  // True until invalidate() has been called. The acquire load pairs with the
  // release store in invalidate().
  bool isValid() const;
  // Marks the token invalid; called from ~KernelFunction().
  void invalidate();

 private:
  std::atomic<bool> invalid_{false};
};

// A handle to a dispatcher kernel that fails with a readable error (instead of
// calling into a destroyed registration) after the kernel is deregistered.
class SafeKernelFunction {
 public:
  // `kernel` is the registered KernelFunction to wrap (copied into kernel_);
  // a token is registered on the original so this handle is invalidated when
  // that registration is destroyed. `debug` is human-readable context appended
  // to the invalidation error message.
  SafeKernelFunction(
      const KernelFunction* kernel,
      std::string debug,
      std::shared_ptr<OperatorHandle> opHandle);

  // Safe callBoxed - checks token validity first, then forwards to the
  // wrapped kernel's callBoxed.
  void callBoxed(
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Stack* stack) const;

  // Get debug information
  const std::string& debug() const {
    return debug_;
  }

  // Get the OpHandle that lives on this SafeKernelFunction
  const OperatorHandle& opHandle() const {
    return *opHandle_;
  }

 private:
  // Private copy of the registered kernel.
  KernelFunction kernel_;
  // Shared with the original KernelFunction (as a weak_ptr there); flipped to
  // invalid when that kernel is destroyed.
  std::shared_ptr<KernelToken> token_;
  std::string debug_;
  std::shared_ptr<OperatorHandle> opHandle_;
};

} // namespace c10

#include <ATen/core/boxing/KernelFunction_impl.h>
47 changes: 47 additions & 0 deletions aten/src/ATen/core/boxing/KernelFunction_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ inline KernelFunction::KernelFunction()
unboxed_kernel_func_(nullptr),
sym_unboxed_kernel_func_(nullptr) {}

inline KernelFunction::~KernelFunction() {
  // Flip every still-alive token to "invalid" so SafeKernelFunction handles
  // observing this kernel refuse to call into a destroyed registration.
  // Expired weak_ptrs (their SafeKernelFunction is already gone) are skipped.
  // NOTE(review): the defaulted copy constructor duplicates tokens_, so
  // destroying a *copy* also invalidates tokens registered on the original —
  // confirm this is intended for temporary KernelFunction copies.
  for (const auto& weak_token : tokens_) {
    const std::shared_ptr<KernelToken> token = weak_token.lock();
    if (token != nullptr) {
      token->invalidate();
    }
  }
}

inline KernelFunction::KernelFunction(
std::unique_ptr<OperatorKernel> functor,
InternalBoxedKernelFunction* boxed_kernel_func,
Expand Down Expand Up @@ -157,6 +165,11 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
std::forward<Args>(args)...);
}

inline void KernelFunction::registerToken(
    std::weak_ptr<KernelToken> token) const {
  // tokens_ is declared mutable so registration can go through the const
  // references the dispatcher hands out (e.g. from the computed table).
  tokens_.emplace_back(std::move(token));
}

inline KernelFunction KernelFunction::makeFromBoxedKernel(
BoxedKernel boxed_fn) {
return KernelFunction(
Expand Down Expand Up @@ -317,4 +330,38 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
std::forward<Lambda>(lambda)));
}

inline bool KernelToken::isValid() const {
  // The acquire load pairs with the release store in invalidate().
  const bool invalidated = invalid_.load(std::memory_order_acquire);
  return !invalidated;
}

// Marks the token invalid. The release store pairs with the acquire load in
// isValid(), so a caller that still sees "valid" observed the token before
// the owning KernelFunction began destruction.
inline void KernelToken::invalidate() {
  invalid_.store(true, std::memory_order_release);
}

inline SafeKernelFunction::SafeKernelFunction(
    const KernelFunction* kernel,
    std::string debug,
    std::shared_ptr<OperatorHandle> opHandle)
    : kernel_(kernel != nullptr ? *kernel : KernelFunction()),
      token_(std::make_shared<KernelToken>()),
      debug_(std::move(debug)),
      opHandle_(std::move(opHandle)) {
  // Register the token on the *original* registered kernel (not on the
  // private copy in kernel_): the token must flip to invalid exactly when
  // that registration is destroyed.
  // NOTE(review): if `kernel` is null the token stays valid while kernel_ is
  // empty, so callBoxed would pass the validity check and invoke an empty
  // kernel — confirm callers never pass null.
  if (kernel != nullptr) {
    kernel->registerToken(token_);
  }
}

inline void SafeKernelFunction::callBoxed(
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Stack* stack) const {
  // Refuse to run once the underlying registration has been torn down.
  const bool alive = token_ != nullptr && token_->isValid();
  TORCH_CHECK(alive, "SafeKernelFunction has been invalidated ", debug_);
  kernel_.callBoxed(opHandle, dispatchKeySet, stack);
}

} // namespace c10
4 changes: 4 additions & 0 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,10 @@ class TORCH_API OperatorHandle {
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
}

  // Returns a SafeKernelFunction wrapping the computed-table kernel for `k`;
  // the handle is invalidated when the underlying registration is destroyed.
  // Forwards to OperatorEntry::getComputedKernelForDispatchKey.
  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.getComputedKernelForDispatchKey(k);
  }

std::string dumpComputedTable() const {
return operatorDef_->op.dumpComputedTable();
}
Expand Down
36 changes: 36 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,42 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
return nullptr;
}

SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
    DispatchKey k) const {
  TORCH_CHECK(
      !isAliasDispatchKey(k),
      "Alias keys do not have runtime kernel registrations.");
  // NOTE(review): assumes getDispatchTableIndexForDispatchKey never yields a
  // negative index for a non-alias runtime key — confirm before indexing.
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
  TORCH_CHECK(
      dispatchTable_[dispatch_ix].isValid(),
      "no kernel for ",
      k,
      " for ",
      name_);

  auto& dispatcher = c10::Dispatcher::singleton();

  // The KernelFunction in dispatchTable_ is only a copy of the KernelFunction
  // held by the AnnotatedKernel in kernels_. A kernel is truly deregistered
  // only when it is removed from kernels_, but the dispatchTable_ copy can be
  // replaced earlier (when a newer kernel is registered). Recompute the
  // dispatch entry so the SafeKernelFunction is backed by the original
  // KernelFunction in kernels_ and is invalidated only on real deregistration.
  auto [annotatedKernel, unused_debug] =
      computeDispatchTableEntryWithDebug(dispatcher, k);

  // findSchemaOrThrow yields an OperatorHandle for this OperatorEntry.
  auto opHandle = dispatcher.findSchemaOrThrow(
      name_.name.c_str(), name_.overload_name.c_str());

  return SafeKernelFunction(
      &annotatedKernel.kernel,
      annotatedKernel.debug,
      std::make_shared<OperatorHandle>(opHandle));
}

const std::vector<at::Tag>& OperatorEntry::getTags() const {
#if defined C10_MOBILE
TORCH_CHECK(false, "tags are not saved for Mobile");
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ class TORCH_API OperatorEntry final {
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
// Returns true if the "computed table" has an entry for a particular key.
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
// Returns a KernelFunction corresponding to the kernel in dispatchTable
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
Expand Down
1 change: 1 addition & 0 deletions docs/source/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: infer_schema
.. autoclass:: torch._library.custom_ops.CustomOpDef
:members: set_kernel_enabled
.. autofunction:: get_kernel
```

## Low-level APIs
Expand Down
143 changes: 143 additions & 0 deletions test/test_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tempfile
import typing
import unittest
from functools import partial
from pathlib import Path
from typing import * # noqa: F403

Expand Down Expand Up @@ -4156,6 +4157,148 @@ def test_any_output_is_alias_to_input_or_output(self):
)
)

def test_library_get_kernel(self):
    """Register a custom CPU kernel, call it through get_kernel, then verify
    the returned handle is invalidated once the registration goes away."""

    # Stand-in for aten::arange.start on CPU: ignores arange semantics and
    # returns a tensor of ones of the requested length instead.
    def ones_instead_of_arange(
        dispatch_keys,
        start,
        end,
        dtype=None,
        layout=torch.strided,
        device=None,
        pin_memory=False,
    ):
        length = max(0, int(end - start))
        return torch.ones(length, dtype=dtype, device=device)

    with torch.library._scoped_library("aten", "IMPL") as lib:
        lib.impl("arange.start", ones_instead_of_arange, "CPU", with_keyset=True)

        kernel = torch.library.get_kernel("aten::arange.start", "CPU")
        keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
        self.assertEqual(kernel.call_boxed(keyset, 0, 5), torch.ones(5))

    # Exiting the scoped library deregisters the kernel, which must
    # invalidate the SafeKernelFunction handle.
    with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
        kernel.call_boxed(keyset, 0, 5)

def test_library_get_kernel_with_conditional_dispatch(self):
    """Test registering custom kernels with conditional dispatch logic.

    Stacks two scoped libraries whose kernels each fall back to the kernel
    handle captured *before* their own registration, then checks that the
    handle captured by the second kernel (which points at the first
    library's registration) is invalidated as soon as that library is
    destroyed, while code paths that avoid the stale handle keep working.
    """
    # Fix: the original helpers computed an unused local
    # `op_handle = torch.ops.aten.arange.start._handle` — dead code, removed.

    def conditional_arange_cpu1(
        original_kernel,
        dispatch_keys,
        start,
        end,
        dtype=None,
        layout=torch.strided,
        device=None,
        pin_memory=False,
    ):
        # If end is even, use the original kernel, otherwise return ones tensor
        if end % 2 == 0:
            return original_kernel.call_boxed(
                dispatch_keys,
                start,
                end,
                dtype=dtype,
                layout=layout,
                device=device,
                pin_memory=pin_memory,
            )
        else:
            size = max(0, int(end - start))
            return torch.ones(size, dtype=dtype, device=device)

    def conditional_arange_cpu2(
        original_kernel,
        dispatch_keys,
        start,
        end,
        dtype=None,
        layout=torch.strided,
        device=None,
        pin_memory=False,
    ):
        # If start is even, use the original kernel, otherwise return twos tensor
        if start % 2 == 0:
            return original_kernel.call_boxed(
                dispatch_keys,
                start,
                end,
                dtype=dtype,
                layout=layout,
                device=device,
                pin_memory=pin_memory,
            )
        else:
            size = max(0, int(end - start))
            return torch.empty(size, dtype=dtype, device=device).fill_(2)

    # Captured while the real aten kernel is still installed.
    original_kernel = torch.library.get_kernel("aten::arange.start", "CPU")
    expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6)
    expected_result3, expected_result4, expected_result5 = (
        torch.ones(5),
        torch.arange(0, 6),
        torch.ones(5).fill_(2),
    )

    with torch.library._scoped_library("aten", "IMPL") as lib2:
        with torch.library._scoped_library("aten", "IMPL") as lib1:
            lib1.impl(
                "arange.start",
                partial(conditional_arange_cpu1, original_kernel),
                "CPU",
                with_keyset=True,
            )

            # lib1's kernel: odd end -> ones, even end -> real arange.
            self.assertEqual(torch.arange(0, 5), expected_result1)
            self.assertEqual(torch.arange(0, 6), expected_result2)

            # This handle points at lib1's registration, not the real kernel.
            new_original_kernel = torch.library.get_kernel(
                "aten::arange.start", "CPU"
            )
            lib2.impl(
                "arange.start",
                partial(conditional_arange_cpu2, new_original_kernel),
                "CPU",
                allow_override=True,
                with_keyset=True,
            )

            # lib2's kernel: even start falls through to lib1's kernel.
            self.assertEqual(torch.arange(0, 5), expected_result3)
            self.assertEqual(torch.arange(0, 6), expected_result4)
            self.assertEqual(torch.arange(1, 6), expected_result5)

        # The kernel should now be invalidated after destroying lib1
        with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
            torch.arange(0, 5)

        # Should still work after destroying lib1 — an odd start never
        # touches the invalidated handle.
        self.assertEqual(torch.arange(1, 6), expected_result5)

def test_library_get_kernel_invalid(self):
    """get_kernel must raise when the requested dispatch key has no kernel."""
    with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib:
        lib.define("cpu_only_op(Tensor x) -> Tensor")
        lib.impl("cpu_only_op", lambda x: x * 2, "CPU")

        # A CPU kernel was registered, so fetching it succeeds.
        self.assertIsNotNone(
            torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CPU")
        )

        # No CUDA kernel exists, so the lookup must fail the isValid() check
        # on the computed dispatch table entry.
        with self.assertRaisesRegex(
            RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op"
        ):
            torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA")


class MiniOpTestOther(CustomOpTestCaseBase):
test_ns = "mini_op_test"
Expand Down
Loading
Loading