# library.py

from ._ops import OpOverload
from typing import Set
import traceback

import torch

__all__ = ['Library', 'impl', 'define']
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality,
# since that could lead to a library calling into a kernel that was never intended to be called.
_impls: Set[str] = set()
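# For example, registering a Python kernel for aten::sum under the CPU dispatch key adds the
# key "aten/sum/CPU" to this set.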
# prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim']

class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.

    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".

    Args:
        ns: library name
        kind: "DEF" or "IMPL"
        dispatch_key: PyTorch dispatch key (default: "")
    """
    def __init__(self, ns, kind, dispatch_key=""):
        if kind != "IMPL" and kind != "DEF":
            raise ValueError(f"Unsupported kind: {kind}")
        if ns in _reserved_namespaces and kind == "DEF":
            raise ValueError(f"{ns} is a reserved namespace. Please try creating a library with another name.")

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        self._op_impls = set()
        self.kind = kind
        self.dispatch_key = dispatch_key

    def __repr__(self):
        return "Library(kind={}, ns={}, dispatch_key={})".format(self.kind, self.ns, self.dispatch_key)

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").

        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This check is here because we also want to disallow PURE_FUNCTION alias analysis,
        # which is a valid AliasAnalysis type in C++.
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError("Invalid alias_analysis type {}".format(alias_analysis))
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Example::
            >>> # xdoctest: +SKIP
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError("Input function is required to be a callable but found type {}".format(type(fn)))
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(name.split("::")[-1], dispatch_key, self.ns))

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

        _impls.add(key)
        self._op_impls.add(key)

    def __del__(self):
        # _op_impls (and m) might not have been initialized if an error was thrown in __init__
        _op_impls_ = getattr(self, '_op_impls', None)
        if _op_impls_:
            for key in _op_impls_:
                _impls.remove(key)
        if hasattr(self, 'm'):
            del self.m


# decorator to register python functions for library ops
# Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""):
    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f
    return wrap
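
# A minimal sketch of the decorator form, assuming a Library instance such as
# `my_lib = Library("aten", "IMPL")` from the `Library.impl` example above:
#
#   @impl(my_lib, "div.Tensor", "CPU")
#   def div_cpu(self, other):
#       return self * (1 / other)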

def define(lib, schema, alias_analysis=""):
    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f
    return wrap
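
# A minimal sketch of the `define` helper, assuming a hypothetical "foo" namespace: it defines
# the schema and registers the decorated function as the implementation (falling back to
# CompositeImplicitAutograd, since no dispatch key is given) in one step.
#
#   foo_lib = Library("foo", "DEF")
#
#   @define(foo_lib, "bar(Tensor self) -> Tensor")
#   def bar(self):
#       return self.clone()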