
Commit b1edaa6

Add utility to get computed kernel in torch.library
ghstack-source-id: 502dccd
Pull Request resolved: #158393
1 parent fc340d0 commit b1edaa6
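
For orientation, here is a minimal usage sketch of the new utility, distilled from the tests added in this PR (`test/test_custom_ops.py`); the `dummy_arange_cpu` override is purely illustrative:

```python
import torch

# Illustrative CPU override for aten::arange.start that ignores arange
# semantics and just returns ones of the requested length.
def dummy_arange_cpu(dispatch_keys, start, end, dtype=None,
                     layout=torch.strided, device=None, pin_memory=False):
    return torch.ones(max(0, int(end - start)), dtype=dtype, device=device)

with torch.library._scoped_library("aten", "IMPL") as lib:
    lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True)

    # Fetch the computed CPU kernel for aten::arange.start and call it boxed.
    kernel = torch.library.get_kernel("aten::arange.start", "CPU")
    keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
    print(kernel.call_boxed(keys, 0, 5))  # tensor of five ones

# After the scoped library is torn down, the kernel is deregistered and further
# call_boxed attempts raise "SafeKernelFunction has been invalidated ...".
```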

File tree: 10 files changed, +416 -0 lines changed


aten/src/ATen/core/boxing/KernelFunction.h

Lines changed: 60 additions & 0 deletions
@@ -2,10 +2,13 @@
 
 #include <ATen/core/ATen_fwd.h>
 #include <ATen/core/boxing/BoxedKernel.h>
+#include <ATen/core/function_schema.h>
 #include <ATen/core/stack.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/util/TypeList.h>
 #include <c10/util/intrusive_ptr.h>
+#include <atomic>
+#include <memory>
 #include <type_traits>
 
 namespace c10 {
@@ -17,6 +20,9 @@ class OperatorHandle;
 struct OperatorKernel;
 class KernelFunction;
 
+class KernelToken;
+class SafeKernelFunction;
+
 template <typename T>
 using has_symint = std::disjunction<
     std::is_same<c10::SymInt, T>,
@@ -90,6 +96,12 @@ class TORCH_API KernelFunction final {
       BoxedKernel::BoxedKernelFunction_withDispatchKeys;
 
   KernelFunction();
+  ~KernelFunction();
+
+  KernelFunction(const KernelFunction&) = default;
+  KernelFunction& operator=(const KernelFunction&) = default;
+
+  KernelFunction(KernelFunction&&) noexcept = default;
 
   // Fast path for dispatch to allow not touching the boxed kernel in
   // the common case where unboxed is available.
@@ -262,6 +274,13 @@ class TORCH_API KernelFunction final {
   // For testing internal invariants only
   bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
 
+  // Register a token to be invalidated when this KernelFunction is destroyed
+  void registerToken(std::weak_ptr<KernelToken> token) const;
+
+  // List of tokens that need to be invalidated when this KernelFunction is
+  // destroyed
+  mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
+
  private:
   explicit KernelFunction(
       std::unique_ptr<OperatorKernel> functor,
@@ -278,6 +297,47 @@ class TORCH_API KernelFunction final {
   void* sym_unboxed_kernel_func_;
 };
 
+// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
+// destroyed
+class KernelToken {
+ public:
+  bool isValid() const;
+  void invalidate();
+
+ private:
+  std::atomic<bool> invalid_{false};
+};
+
+class SafeKernelFunction {
+ public:
+  SafeKernelFunction(
+      const KernelFunction* kernel,
+      std::string debug,
+      std::shared_ptr<OperatorHandle> opHandle);
+
+  // Safe callBoxed - checks token validity first
+  void callBoxed(
+      const OperatorHandle& opHandle,
+      DispatchKeySet dispatchKeySet,
+      Stack* stack) const;
+
+  // Get debug information
+  const std::string& debug() const {
+    return debug_;
+  }
+
+  // Get the OpHandle that lives on this SafeKernelFunction
+  const OperatorHandle& opHandle() const {
+    return *opHandle_;
+  }
+
+ private:
+  KernelFunction kernel_;
+  std::shared_ptr<KernelToken> token_;
+  std::string debug_;
+  std::shared_ptr<OperatorHandle> opHandle_;
+};
+
 } // namespace c10
 
 #include <ATen/core/boxing/KernelFunction_impl.h>

aten/src/ATen/core/boxing/KernelFunction_impl.h

Lines changed: 47 additions & 0 deletions
@@ -24,6 +24,14 @@ inline KernelFunction::KernelFunction()
       unboxed_kernel_func_(nullptr),
       sym_unboxed_kernel_func_(nullptr) {}
 
+inline KernelFunction::~KernelFunction() {
+  for (auto& weak_token : tokens_) {
+    if (auto token = weak_token.lock()) {
+      token->invalidate();
+    }
+  }
+}
+
 inline KernelFunction::KernelFunction(
     std::unique_ptr<OperatorKernel> functor,
     InternalBoxedKernelFunction* boxed_kernel_func,
@@ -157,6 +165,11 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
       std::forward<Args>(args)...);
 }
 
+inline void KernelFunction::registerToken(
+    std::weak_ptr<KernelToken> token) const {
+  tokens_.push_back(std::move(token));
+}
+
 inline KernelFunction KernelFunction::makeFromBoxedKernel(
     BoxedKernel boxed_fn) {
   return KernelFunction(
@@ -317,4 +330,38 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
           std::forward<Lambda>(lambda)));
 }
 
+inline bool KernelToken::isValid() const {
+  return !invalid_.load(std::memory_order_acquire);
+}
+
+inline void KernelToken::invalidate() {
+  invalid_.store(true, std::memory_order_release);
+}
+
+inline SafeKernelFunction::SafeKernelFunction(
+    const KernelFunction* kernel,
+    std::string debug,
+    std::shared_ptr<OperatorHandle> opHandle)
+    : kernel_(kernel ? *kernel : KernelFunction()),
+      token_(std::make_shared<KernelToken>()),
+      debug_(std::move(debug)),
+      opHandle_(std::move(opHandle)) {
+  // Register the token with the original kernel so it gets invalidated when the
+  // kernel is destroyed
+  if (kernel) {
+    kernel->registerToken(token_);
+  }
+}
+
+inline void SafeKernelFunction::callBoxed(
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    Stack* stack) const {
+  TORCH_CHECK(
+      token_ && token_->isValid(),
+      "SafeKernelFunction has been invalidated ",
+      debug_);
+  kernel_.callBoxed(opHandle, dispatchKeySet, stack);
+}
+
 } // namespace c10
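
The lifetime handshake implemented above is easier to see without the dispatcher around it. The following stand-alone Python analogy (class names mirror the C++ ones, but everything here is a hypothetical illustration, not PyTorch API) shows the same pattern: each handle holds a token, the kernel holds weak references to those tokens, and destroying the kernel flips them invalid:

```python
import weakref


class KernelToken:
    """Analogue of c10::KernelToken: a flag the owning kernel can flip."""

    def __init__(self):
        self._invalid = False

    def is_valid(self):
        return not self._invalid

    def invalidate(self):
        self._invalid = True


class Kernel:
    """Analogue of KernelFunction: tracks weak refs to outstanding tokens."""

    def __init__(self, fn):
        self._fn = fn
        self._tokens = []  # plays the role of KernelFunction::tokens_

    def register_token(self, token):
        self._tokens.append(weakref.ref(token))

    def __del__(self):  # plays the role of ~KernelFunction
        for ref in self._tokens:
            token = ref()
            if token is not None:
                token.invalidate()


class SafeKernel:
    """Analogue of SafeKernelFunction: checks its token before every call."""

    def __init__(self, kernel, debug=""):
        self._fn = kernel._fn  # copy of the payload, like the kernel_ copy
        self._token = KernelToken()
        self._debug = debug
        kernel.register_token(self._token)

    def call(self, *args):
        if not self._token.is_valid():
            raise RuntimeError(f"SafeKernelFunction has been invalidated {self._debug}")
        return self._fn(*args)


k = Kernel(lambda x: x + 1)
safe = SafeKernel(k, debug="demo kernel")
assert safe.call(1) == 2
del k  # "deregistration": the destructor invalidates the outstanding token
# safe.call(1) now raises RuntimeError: SafeKernelFunction has been invalidated demo kernel
```

The weak references matter in both versions: the kernel never keeps its handles' tokens alive, so a handle that goes away on its own does not leak anything from the kernel's side.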

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 4 additions & 0 deletions
@@ -487,6 +487,10 @@ class TORCH_API OperatorHandle {
     return operatorDef_->op.hasComputedKernelForDispatchKey(k);
   }
 
+  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.getComputedKernelForDispatchKey(k);
+  }
+
   std::string dumpComputedTable() const {
     return operatorDef_->op.dumpComputedTable();
   }

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 36 additions & 0 deletions
@@ -315,6 +315,42 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
   return nullptr;
 }
 
+SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
+    DispatchKey k) const {
+  TORCH_CHECK(
+      !isAliasDispatchKey(k),
+      "Alias keys do not have runtime kernel registrations.");
+  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
+  TORCH_CHECK(
+      dispatchTable_[dispatch_ix].isValid(),
+      "no kernel for ",
+      k,
+      " for ",
+      name_);
+
+  // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction
+
+  // The KernelFunction object in dispatchTable_ is a copy of the KernelFunction
+  // in the AnnotatedKernel in kernels_. A KernelFunction is only truly
+  // deregistered when the kernel is removed from kernels_. However, the
+  // KernelFunction in dispatchTable_ might be removed before it is deregistered
+  // (when a newer kernel is registered). Therefore, here we want to return a
+  // SafeKernelFunction that is backed by the original KernelFunction in
+  // kernels_, so that we only invalidate it when the kernel is deregistered.
+  auto [annotatedKernel, _] =
+      computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
+
+  // Use findSchemaOrThrow to get OpHandle for the OperatorEntry
+  auto& dispatcher = c10::Dispatcher::singleton();
+  auto opHandle = dispatcher.findSchemaOrThrow(
+      name_.name.c_str(), name_.overload_name.c_str());
+
+  return SafeKernelFunction(
+      &annotatedKernel.kernel,
+      annotatedKernel.debug,
+      std::make_shared<OperatorHandle>(opHandle));
+}
+
 const std::vector<at::Tag>& OperatorEntry::getTags() const {
 #if defined C10_MOBILE
   TORCH_CHECK(false, "tags are not saved for Mobile");
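
The comment in getComputedKernelForDispatchKey is the key design point: a handle returned by the new API survives being shadowed by a newer registration and only dies when its own registration is removed. A sketch of that behavior, distilled from the new test_library_get_kernel_with_conditional_dispatch test below (ones_arange is an illustrative stand-in):

```python
import torch

# Handle to whatever currently backs aten::arange.start on CPU.
original = torch.library.get_kernel("aten::arange.start", "CPU")
keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)

def ones_arange(dispatch_keys, start, end, dtype=None,
                layout=torch.strided, device=None, pin_memory=False):
    return torch.ones(max(0, int(end - start)), dtype=dtype, device=device)

with torch.library._scoped_library("aten", "IMPL") as lib:
    # A newer registration replaces the entry in the computed dispatch table...
    lib.impl("arange.start", ones_arange, "CPU", with_keyset=True)
    print(torch.arange(0, 5))  # dispatch now reaches ones_arange

    # ...but the handle fetched earlier is backed by the original entry in
    # kernels_, which was never deregistered, so it still calls the old kernel.
    print(original.call_boxed(keys, 0, 5))  # expected: tensor([0, 1, 2, 3, 4])
```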

aten/src/ATen/core/dispatch/OperatorEntry.h

Lines changed: 2 additions & 0 deletions
@@ -217,6 +217,8 @@ class TORCH_API OperatorEntry final {
   const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
   // Returns true if the "computed table" has an entry for a particular key.
   bool hasComputedKernelForDispatchKey(DispatchKey k) const;
+  // Returns a KernelFunction corresponding to the kernel in dispatchTable
+  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;
   void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);

docs/source/library.md

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ via PyTorch's C++ operator registration APIs).
 .. autofunction:: infer_schema
 .. autoclass:: torch._library.custom_ops.CustomOpDef
     :members: set_kernel_enabled
+.. autofunction:: get_kernel
 ```
 
 ## Low-level APIs

test/test_custom_ops.py

Lines changed: 143 additions & 0 deletions
@@ -11,6 +11,7 @@
 import tempfile
 import typing
 import unittest
+from functools import partial
 from pathlib import Path
 from typing import *  # noqa: F403
 
@@ -4156,6 +4157,148 @@ def test_any_output_is_alias_to_input_or_output(self):
             )
         )
 
+    def test_library_get_kernel(self):
+        """Test registering a custom kernel, using it, then deregistering and verifying error."""
+
+        # Register a dummy kernel for arange to the CPU key that returns a tensor of ones
+        def dummy_arange_cpu(
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            size = max(0, int(end - start))
+            return torch.ones(size, dtype=dtype, device=device)
+
+        with torch.library._scoped_library("aten", "IMPL") as lib:
+            lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True)
+
+            kernel = torch.library.get_kernel("aten::arange.start", "CPU")
+            dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
+            result = kernel.call_boxed(dispatch_keys, 0, 5)
+
+            self.assertEqual(result, torch.ones(5))
+
+        # The kernel should now be invalidated after exiting the scoped_library context
+        with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
+            kernel.call_boxed(dispatch_keys, 0, 5)
+
+    def test_library_get_kernel_with_conditional_dispatch(self):
+        """Test registering a custom kernel with conditional dispatch logic."""
+
+        def conditional_arange_cpu1(
+            original_kernel,
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            # If end is even, use the original kernel, otherwise return ones tensor
+            if end % 2 == 0:
+                op_handle = torch.ops.aten.arange.start._handle
+                return original_kernel.call_boxed(
+                    dispatch_keys,
+                    start,
+                    end,
+                    dtype=dtype,
+                    layout=layout,
+                    device=device,
+                    pin_memory=pin_memory,
+                )
+            else:
+                size = max(0, int(end - start))
+                return torch.ones(size, dtype=dtype, device=device)
+
+        def conditional_arange_cpu2(
+            original_kernel,
+            dispatch_keys,
+            start,
+            end,
+            dtype=None,
+            layout=torch.strided,
+            device=None,
+            pin_memory=False,
+        ):
+            # If start is even, use the original kernel, otherwise return twos tensor
+            if start % 2 == 0:
+                op_handle = torch.ops.aten.arange.start._handle
+                return original_kernel.call_boxed(
+                    dispatch_keys,
+                    start,
+                    end,
+                    dtype=dtype,
+                    layout=layout,
+                    device=device,
+                    pin_memory=pin_memory,
+                )
+            else:
+                size = max(0, int(end - start))
+                return torch.empty(size, dtype=dtype, device=device).fill_(2)
+
+        original_kernel = torch.library.get_kernel("aten::arange.start", "CPU")
+        expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6)
+        expected_result3, expected_result4, expected_result5 = (
+            torch.ones(5),
+            torch.arange(0, 6),
+            torch.ones(5).fill_(2),
+        )
+
+        with torch.library._scoped_library("aten", "IMPL") as lib2:
+            with torch.library._scoped_library("aten", "IMPL") as lib1:
+                lib1.impl(
+                    "arange.start",
+                    partial(conditional_arange_cpu1, original_kernel),
+                    "CPU",
+                    with_keyset=True,
+                )
+
+                self.assertEqual(torch.arange(0, 5), expected_result1)
+                self.assertEqual(torch.arange(0, 6), expected_result2)
+                new_original_kernel = torch.library.get_kernel(
+                    "aten::arange.start", "CPU"
+                )
+                lib2.impl(
+                    "arange.start",
+                    partial(conditional_arange_cpu2, new_original_kernel),
+                    "CPU",
+                    allow_override=True,
+                    with_keyset=True,
+                )
+
+                self.assertEqual(torch.arange(0, 5), expected_result3)
+                self.assertEqual(torch.arange(0, 6), expected_result4)
+                self.assertEqual(torch.arange(1, 6), expected_result5)
+
+            # The kernel should now be invalidated after destroying lib1
+            with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
+                torch.arange(0, 5)
+
+            # Should still work after destroying lib1
+            self.assertEqual(torch.arange(1, 6), expected_result5)
+
+    def test_library_get_kernel_invalid(self):
+        """Test that get_kernel raises an error when no kernel is available."""
+        with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib:
+            lib.define("cpu_only_op(Tensor x) -> Tensor")
+            lib.impl("cpu_only_op", lambda x: x * 2, "CPU")
+
+            cpu_kernel = torch.library.get_kernel(
+                "test_invalid_kernel::cpu_only_op", "CPU"
+            )
+            self.assertIsNotNone(cpu_kernel)
+
+            # CUDA should fail at the isValid() check since no CUDA kernel exists
+            with self.assertRaisesRegex(
+                RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op"
+            ):
+                torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA")
+
 
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
