123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- import re
- import torch._C as C
- """
- PythonDispatcher class is a thin python-binding to C++ dispatcher and it
- is designed to show how dispatcher precompute works. In particular,
- it shows for a certain op `foo`, what the computed dispatch table looks
- like after users register their kernels to certain dispatch keys.
- In the real C++ dispatcher we support many dispatch keys for different
- functionalities. For simplicity PythonDispatcher only supports dispatch
- keys for a single example of each use case. These use cases are listed below:
- - CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
- autograd kernel in pytorch core library.
- E.g. CPU, CUDA
- - FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
- inference kernels, but they share the same autograd kernel specified in AutogradOther.
- E.g. FPGA, SparseCsrCPU
- - XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
- kernel defined in pytorch core library. Backend owner is responsible for registering both
- inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
- E.g. XLA, XPU, MPS
- - CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
- Kernels registered to this key MUST work for inference for all backends.
- - Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
- Kernels registered to this key MUST work for autograd for all backends.
- - CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
- Kernels registered to this key MUST work for both inference + autograd for all backends.
- Note we only allow registrations to alias keys inside pytorch core library. E.g
- you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
- kernel from torch-xla extension, instead you should upstream the kernel into
- pytorch/pytorch repo so that it's available for all backends and continuously
- tested even without the extension.
- Usage:
- dispatcher = PythonDispatcher()
- dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
- print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
- # For more debugging information
- # print(dispatcher.keys())
- # print(dispatcher.registrations())
- # print(dispatcher.rawRegistrations())
- # print(dispatcher.rawDispatchTable())
- PythonDispatcher calls the C++ dispatcher under the hood to precompute the dispatch table.
- This file only provides the simplified API for developers, relevant test code is located in
- test/test_dispatch.py
- """
class PythonDispatcher:
    """
    Thin python-binding to the C++ dispatcher, illustrating how dispatch
    precomputation works for a single test operator `__test__::foo`.

    See the module docstring for the meaning of each supported dispatch key.
    Use register() to add kernels, then dispatchTable() to inspect the
    computed table.
    """

    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self):
        """Define the test operator `foo` in a fresh FRAGMENT library."""
        # Sanity-check dispatcher invariants for this op before defining it.
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        # Same schema as before; the op name is derived from the class
        # attribute so `name` stays the single source of truth.
        self.ref.def_(f"{self.name}(Tensor x) -> Tensor")

    def keys(self):
        """
        Return the list of dispatch keys supported by PythonDispatcher.
        You can register kernels to these keys.
        """
        return self.supported_keys

    def register(self, dispatchKeys):
        """
        Register kernels to the target dispatchKeys.

        dispatchKeys (list[str]): dispatch keys to register kernels for.
            You don't need to write the kernels yourself: for each key a
            placeholder kernel (e.g. fn_CPU for CPU) is automatically
            generated and registered.

        Raises RuntimeError on duplicate keys, on registering to both
        composite alias keys, or on an unsupported key.
        """
        # Overriding is not supported and triggers a warning in the C++
        # dispatcher, so reject duplicates up front.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of the C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
            self.ref.impl_t_t(self.name, dispatch=key, debug="fn_" + key)

    def _format_line(self, key, kernel):
        """Format one (key, kernel) table row."""
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header):
        """Format a table header followed by the column names and a rule."""
        s = f"\n{header}\n"
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    def rawRegistrations(self):
        """
        Return raw output of all registration info, for debugging only.
        Use registrations() for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self):
        """
        Return raw output of the computed dispatch table, for debugging only.
        Use dispatchTable() for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self):
        """
        Return a table (str) including all the registrations from users.
        Note this includes registrations to both runtime keys and alias keys.
        """
        output = self._format_header("Registered Kernels")
        state = self.rawRegistrations()
        for line in state.split("\n"):
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    def dispatchTable(self):
        """
        Return the computed dispatch table (str). Note this only includes
        runtime keys; registrations to alias keys have been decoded to their
        mapped runtime keys.
        """
        output = self._format_header("Computed Dispatch Table")
        table = self.rawDispatchTable()
        # Strip the "registered at ...FallbackKernel.cpp:<line>" location
        # noise so the rendered table is stable across builds.
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table.split("\n"):
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                output += self._format_line(k, entry.split(": ")[1])
        return output
|