New Feature Pitch
Often people write cool, super fast kernels that overlap with existing implementations in PyTorch. However, PyTorch tensors have a very large configuration space: non-contiguous layouts, nonzero storage offsets, arbitrary dims, etc. Kernel authors are often able to make their implementations so fast precisely because they focus on a specific subset of that configuration space; a small illustrative guard is sketched below.
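For example (illustrative only, not part of the proposal), the supported slice of that configuration space might be expressed as a guard like:

```python
import torch

def my_kernel_supports(t: torch.Tensor) -> bool:
    # A narrow slice of the configuration space a hand-written kernel
    # might target: 2-D, contiguous, zero storage offset, fp16 on CUDA.
    return (t.dim() == 2
            and t.is_contiguous()
            and t.storage_offset() == 0
            and t.dtype == torch.float16
            and t.is_cuda)
```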
The pitch is that we should design a nice extension mechanism for people to "bring their own kernels" to PyTorch while also being able to fall back to PyTorch's base implementation for the cases their kernels don't support. I have coined this "intra kernel dispatching".
Rough Design
This would let users do something like
import torch

# AtenIntraKernelDispatcher is the proposed API; it does not exist in PyTorch today.
dispatcher = AtenIntraKernelDispatcher()

# Example 1: Override mm for small square matrices
def is_small_square(self: torch.Tensor, mat2: torch.Tensor) -> bool:
    return (self.dim() == 2 and mat2.dim() == 2 and
            self.size(0) == self.size(1) == mat2.size(0) == mat2.size(1) <= 64)

def custom_mm_kernel(self: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    print(f"[CUSTOM MM] {self.shape} @ {mat2.shape}")
    # Your optimized implementation here
    return torch.matmul(self, mat2)

dispatcher.register("mm", is_small_square, custom_mm_kernel)

print("Testing mm dispatch:")

# Should use the custom kernel
print("\n1. Small square (32x32):")
a = torch.randn(32, 32)
b = torch.randn(32, 32)
c = torch.mm(a, b)
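Until something like this exists at the dispatcher level, the same predicate-plus-fallback idea can be approximated in Python with a TorchDispatchMode. This is only a sketch of the intended behavior, not the proposed mechanism (it adds Python overhead on every op, and IntraKernelDispatchMode and its register method are made-up names):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class IntraKernelDispatchMode(TorchDispatchMode):
    """Per-op (predicate, kernel) overrides that fall back to the stock op."""

    def __init__(self):
        super().__init__()
        self._registry = {}  # OpOverload -> list of (predicate, kernel)

    def register(self, op, predicate, kernel):
        self._registry.setdefault(op, []).append((predicate, kernel))

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for predicate, kernel in self._registry.get(func, []):
            if predicate(*args, **kwargs):
                return kernel(*args, **kwargs)
        # No predicate matched: fall back to PyTorch's own implementation.
        return func(*args, **kwargs)

mode = IntraKernelDispatchMode()
mode.register(torch.ops.aten.mm.default, is_small_square, custom_mm_kernel)

with mode:
    torch.mm(torch.randn(32, 32), torch.randn(32, 32))    # takes the custom path
    torch.mm(torch.randn(128, 64), torch.randn(64, 32))   # falls back to aten::mm
```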
This would allow users to incrementally build libraries of external kernels. Those libraries could then be loaded by PyTorch users to get instant speedups. Similar ideas can be seen in FlagGems: https://github.com/FlagOpen/FlagGems
A Few Existing Problems
I was looking at this with @albanD; there are a few problems today. The main one is that once you override your op, there is no nice way to get the original implementation. The implementation would check whether there is a computed key: if an existing computed key is present, then there is in fact something to override. However, we don't expose the ability today to get the KernelFunction for computed keys like we do for non-computed keys.
We would then need to pybind KernelFunction's callBoxed to use as the fallback path in the dispatcher.
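To make the gap concrete, this is roughly what an override could look like if such a binding existed. The `_dispatch_get_computed_kernel` name below is invented purely for illustration and does not exist in torch today:

```python
import torch

op = torch.ops.aten.mm.default

# Hypothetical binding (does NOT exist today): return a callable that wraps the
# KernelFunction currently computed for this op at a given dispatch key and
# invokes it via callBoxed.
original_mm = torch._C._dispatch_get_computed_kernel(op, "CPU")

def mm_override(a, b):
    if is_small_square(a, b):            # predicate from the example above
        return custom_mm_kernel(a, b)    # specialized path
    return original_mm(a, b)             # fall back to the original kernel
```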
Besides all of the work above, there is another problem: a Library in PyTorch is responsible for the lifetime of the raw function pointer it registers (described in the mermaid diagram below). It is possible to have a chain Lib1 -> Lib2 -> Lib3 where Lib3 stores a reference to an original_implementation that is backed by a raw function owned by Lib2. If Lib2 gets destructed, that raw function no longer has an owner, which can lead to a use-after-free. An initial option is to check that the previous Library (and/or all of its dependencies) is still alive before calling the raw function; a toy model of that check is sketched after the diagram.
graph TD
Dispatcher --> |owns| OperatorHandle
OperatorHandle --> |owns| OperatorEntry
OperatorEntry --> |owns| KernelFunction
CppLibrary[C++ Library manages lifetime] --> |owns raw function pointer| KernelFunction
PythonLibrary[Python Library torch.library.py] --> |Python wrapper around| CppLibrary
CppLibrary -.-> |deregistration can invalidate pointer| KernelFunction
classDef cppNode fill:#e3f2fd,stroke:#1976d2
classDef pythonNode fill:#ffebee,stroke:#d32f2f
classDef problemNode fill:#fff3e0,stroke:#f57c00
class Dispatcher,OperatorHandle,OperatorEntry,KernelFunction,CppLibrary cppNode
class PythonLibrary pythonNode
class CppLibrary problemNode
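As a toy Python model of the "check the owner is still alive" option: FakeLibrary and the weakref check below are illustrative only, since the real dangling pointer lives in C++ where Python refcounting offers no protection:

```python
import weakref

class FakeLibrary:
    """Toy stand-in for a library that owns the raw kernels it registered."""
    def __init__(self, name):
        self.name = name
        self.kernels = {}

lib2 = FakeLibrary("Lib2")
lib2.kernels["mm"] = lambda a, b: a @ b        # "raw function" owned by Lib2

# Lib3 wants to fall back to Lib2's mm, so it captures the implementation plus
# a weak reference to the owning library, to detect when the owner is gone.
original_mm = lib2.kernels["mm"]
lib2_alive = weakref.ref(lib2)

def lib3_mm(a, b):
    if lib2_alive() is None:
        raise RuntimeError("owner of the original mm implementation was destroyed")
    return original_mm(a, b)

del lib2  # Lib2 deregisters; in C++ the raw function pointer would now dangle,
          # but lib3_mm refuses to call into a kernel whose owner is gone.
```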
cc @msaroufim on shipping Popcorn