
IntraKernel Dispatcher #155330

@drisspg

Description


New Feature Pitch

Often people write cool, very fast kernels that overlap with existing implementations in PyTorch. However, PyTorch tensors can have a very large configuration space (non-contiguous layouts, nonzero storage offsets, arbitrary dims, etc.), and kernel authors are often able to make their implementations so fast precisely because they focus on a specific subset of that configuration space.

The pitch is that we should design a nice extension mechanism for people to "bring their own kernels" to PyTorch while still being able to fall back to PyTorch's base implementation for the cases their kernels don't support. I have coined this "intra-kernel dispatching".

Rough Design

This would let users do something like

  import torch

  dispatcher = AtenIntraKernelDispatcher()

  # Example 1: Override mm for small square matrices
  def is_small_square(self: torch.Tensor, mat2: torch.Tensor) -> bool:
      return (self.dim() == 2 and mat2.dim() == 2 and
              self.size(0) == self.size(1) == mat2.size(0) == mat2.size(1) <= 64)

  def custom_mm_kernel(self: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
      print(f"[CUSTOM MM] {self.shape} @ {mat2.shape}")
      # Your optimized implementation here
      return torch.matmul(self, mat2)

  dispatcher.register("mm", is_small_square, custom_mm_kernel)

  print("Testing mm dispatch:")

  # Should use custom kernel
  print("\n1. Small square (32x32):")
  a = torch.randn(32, 32)
  b = torch.randn(32, 32)
  c = torch.mm(a, b)

This would allow users to build libraries of external kernels incrementally. These libraries could then be loaded by PyTorch users for instant speedups. Similar ideas can be seen in https://github.com/FlagOpen/FlagGems.
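For reference, something close to these semantics can already be prototyped at the Python level with the existing TorchFunctionMode API. The sketch below is just my approximation of the intended behavior (it reuses is_small_square / custom_mm_kernel from above and only applies while the mode is active), not the C++-level mechanism this issue is actually proposing:

  import torch
  from torch.overrides import TorchFunctionMode

  class IntraKernelDispatchMode(TorchFunctionMode):
      """Route a call to a custom kernel when its predicate matches,
      otherwise fall through to the stock PyTorch implementation."""

      def __init__(self):
          super().__init__()
          self._overrides = {}  # torch function -> list of (predicate, kernel)

      def register(self, func, predicate, kernel):
          self._overrides.setdefault(func, []).append((predicate, kernel))

      def __torch_function__(self, func, types, args=(), kwargs=None):
          kwargs = kwargs or {}
          for predicate, kernel in self._overrides.get(func, []):
              if predicate(*args, **kwargs):
                  return kernel(*args, **kwargs)
          # The mode is not re-entered for this call, so it reaches the
          # original PyTorch implementation.
          return func(*args, **kwargs)

  mode = IntraKernelDispatchMode()
  mode.register(torch.mm, is_small_square, custom_mm_kernel)  # from the snippet above

  with mode:
      torch.mm(torch.randn(32, 32), torch.randn(32, 32))    # custom path
      torch.mm(torch.randn(128, 64), torch.randn(64, 128))  # stock aten::mm

The ask here is to get equivalent per-op behavior at the dispatcher level rather than only while a Python mode is active.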

A Few Existing Problems

I was looking at this with @albanD; there are a few problems today. The main one is that once you override your op, there is no nice way to get back the original implementation. The implementation would check whether there is a computed key: if an existing computed key is present, then there is in fact something to override. However, today we don't expose the ability to get the KernelFunction for computed keys like we do for non-computed keys.
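For context, the introspection we do expose today (to the best of my knowledge; these are private bindings, so treat the exact names and signatures as an assumption) lets you check for registered kernels and print the computed table, but not retrieve the kernel itself:

  import torch

  # Check whether a kernel is registered for a dispatch key (private binding)
  print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::mm", "CPU"))

  # Dump the computed dispatch table as text -- useful for debugging, but it
  # does not hand back a callable KernelFunction to stash as a fallback.
  print(torch._C._dispatch_dump_table("aten::mm"))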

We would then need to pybind KernelFunction::callBoxed to use as the fallthrough in the dispatcher.

Besides all of the work above, there is another problem: a Library in PyTorch is responsible for the lifetime of the raw function pointer, as described in the mermaid diagram below. It is possible to have

Lib1 -> Lib2 -> Lib3, where Lib3 stores a reference to an original_implementation that is based on a raw function in Lib2. If Lib2 gets destructed, then the raw function has no owner, and this can lead to a use-after-free. An initial option is to check that the previous Lib, and/or all of its dependencies, are still alive before calling the raw function (see the sketch after the diagram).

  graph TD
      Dispatcher --> |owns| OperatorHandle
      OperatorHandle --> |owns| OperatorEntry
      OperatorEntry --> |owns| KernelFunction

      CppLibrary[C++ Library manages lifetime] --> |owns raw function pointer| KernelFunction
      PythonLibrary[Python Library torch.library.py] --> |Python wrapper around| CppLibrary

      CppLibrary -.-> |deregistration can invalidate pointer| KernelFunction

      classDef cppNode fill:#e3f2fd,stroke:#1976d2
      classDef pythonNode fill:#ffebee,stroke:#d32f2f
      classDef problemNode fill:#fff3e0,stroke:#f57c00

      class Dispatcher,OperatorHandle,OperatorEntry,KernelFunction,CppLibrary cppNode
      class PythonLibrary pythonNode
      class CppLibrary problemNode
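To make the "check liveness before calling" option concrete, here is a rough sketch of what such a guard could look like on the Python side. The names are hypothetical, and this doesn't resolve the underlying C++ ownership question; it just turns a silent use-after-free into a loud error:

  import weakref

  class GuardedFallback:
      """Hypothetical guard around a captured original implementation: refuse
      to call it once the Library that owns the raw function has been
      deregistered, instead of risking a use-after-free."""

      def __init__(self, owning_library, original_impl):
          # weakref so the guard does not artificially extend the Library's lifetime
          self._owner = weakref.ref(owning_library)
          self._original_impl = original_impl

      def __call__(self, *args, **kwargs):
          if self._owner() is None:
              raise RuntimeError(
                  "The Library owning the original implementation was "
                  "deregistered; this fallback is no longer valid"
              )
          return self._original_impl(*args, **kwargs)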

cc @msaroufim on shipping Popcorn
