import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict

import torch._C
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)
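
# Illustrative sketch (not executed): `dl_open_guard` is meant to wrap the
# `ctypes.CDLL` call that loads a custom-operator library, so the library is
# opened with RTLD_GLOBAL and its symbols are visible to libraries loaded
# later. The path below is a hypothetical placeholder.
#
#     with dl_open_guard():
#         ctypes.CDLL("/path/to/libcustom_ops.so")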


def has_key(op, k):
    return (
        torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
        or k in op.py_kernels
    )


# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
    def __call__(self, *args, **kwargs):
        pass

    def py_impl(self, dispatch_key, fn):
        pass

    def name(self):
        pass


is_included_in_alias = torch._C._dispatch_is_included_in_alias

DispatchKey = torch._C.DispatchKey


# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    has_backend_kernel = torch._C._dispatch_has_kernel_for_any_dispatch_key(
        op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or has_key(op, DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and has_key(op, cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(
                op.name(), torch._C._dispatch_autogradother_backends
            )
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and has_key(op, cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
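
# Illustrative sketch (not executed): `resolve_key` maps the dispatch key computed
# from the inputs to the key whose kernel actually runs, following the numbered
# steps above. For an op with a direct backend registration (e.g. aten::add.Tensor
# on a typical build), step 1 returns the queried key unchanged:
#
#     resolve_key(torch.ops.aten.add.Tensor, torch._C.DispatchKey.CPU)
#     # -> DispatchKey.CPU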

pyop_namespace = {}


class PyOperator(PyOperatorABC):
    def __init__(self, name):
        self._name = name
        self.table = {}
        self.python_key_mode_table = {}
        self.functorch_table = {}

        # Make _OpNamespace not scream; this whole name-based association needs a good hard look
        self.__name__ = name
        pyop_namespace[name] = self

    def fallthrough(self, dispatch_key):
        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)

    def py_impl(self, dispatch_key_or_mode_or_transform):
        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass(
                dispatch_key_or_mode_or_transform,
                torch.utils._python_dispatch.TorchDispatchMode,
            ):
                mode = dispatch_key_or_mode_or_transform
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                return fn

            if isinstance(
                dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType
            ):
                transform = dispatch_key_or_mode_or_transform
                self.functorch_table[transform] = fn
                return fn

            dispatch_key = dispatch_key_or_mode_or_transform
            assert (
                dispatch_key != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            assert isinstance(dispatch_key, torch._C.DispatchKey)
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn

        return inner

    def dispatch(self, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == torch._C.DispatchKey.Python:
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            curr_mode = _get_current_dispatch_mode()
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            assert (
                type(curr_mode) in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)

        assert dispatch_key in self.table, dispatch_key
        return self.table[dispatch_key](*args, **kwargs)

    def __call__(self, *args, **kwargs):
        flat_args = _to_flat_tuple(args, kwargs)
        if torch.overrides.has_torch_function(flat_args):
            return torch.overrides.handle_torch_function(
                self, flat_args, *args, **kwargs
            )

        dispatch_key_set = _compute_keyset(args, kwargs)
        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

    def name(self):
        return self._name

    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
    # as opposed to being this sort of explicit thing where ops are a little too key aware...
    def _fallthrough_fn(self, operator, dispatch_key):
        def inner(*args, **kwargs):
            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)
            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
                args, kwargs
            )
            return self.dispatch(
                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
            )

        return inner
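
# Illustrative sketch (not executed): a PyOperator is constructed with a name and
# then has Python kernels registered via the `py_impl` decorator, keyed by dispatch
# key, TorchDispatchMode subclass, or functorch transform. `my_pyop` is a
# hypothetical operator, not one defined by this module.
#
#     my_pyop = PyOperator("my_pyop")
#
#     @my_pyop.py_impl(torch._C.DispatchKey.CPU)
#     def my_pyop_cpu(*args, **kwargs):
#         ...
#
#     my_pyop.fallthrough(torch._C.DispatchKey.AutogradCPU)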


def _to_flat_tuple(args, kwargs):
    flat_args, _ = torch.utils._pytree.tree_flatten(args)
    flat_kwargs, _ = torch.utils._pytree.tree_flatten(kwargs)
    flat_all = flat_args + flat_kwargs
    return flat_all


def _compute_keyset(args, kwargs):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    return key_set
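
# Illustrative sketch (not executed): the dispatch key set for a call is the union
# of each tensor argument's key set and the thread-local include set, minus the
# thread-local exclude set; `PyOperator.__call__` then dispatches to the highest
# priority key in that set.
#
#     t = torch.randn(2)
#     ks = key_extractor((t,))
#     first_key = ks.highestPriorityTypeId()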


# Each OpOverload object contains a pointer to a specific operator overload and a
# pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on an OpOverloadPacket.
class OpOverload(PyOperatorABC):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.py_kernels: Dict[torch._C.DispatchKey, Any] = {}  # type: ignore[name-defined]
        self.__name__ = "{}.{}".format(
            self._schema.name.split("::")[1], self._overloadname
        )
        # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
        self.python_key_mode_table = {}
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        self._dispatch_cache = {}

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        is_write = None
        for a in self._schema.arguments:
            if a.alias_info is None:
                continue
            if is_write is None:
                is_write = a.alias_info.is_write
            else:
                # We will conservatively call mixed mutable/non-mutable
                # aliased inputs as NOT a view
                is_write = a.alias_info.is_write or is_write
        self.is_view = is_write is not None and not is_write

    # It's a no-op since the OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs or {})

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    @property
    def namespace(self):
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using the Python dispatcher (also taking advantage of the autograd
            # formula). But it's included for completeness.
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented

    def py_impl(self, dispatch_key_or_mode):
        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
            ):
                mode = dispatch_key_or_mode
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                self._dispatch_cache.clear()
                return fn

            assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)
            assert (
                dispatch_key_or_mode != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if dispatch_key_or_mode in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}"
                )
            self.py_kernels[dispatch_key_or_mode] = fn
            self._dispatch_cache.clear()
            return fn

        return inner
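
    # Illustrative sketch (not executed): `py_impl` registers a Python kernel for a
    # specific dispatch key (or TorchDispatchMode) on this overload, which the Python
    # dispatcher consults. `_my_add_kernel` and its signature are hypothetical.
    #
    #     @torch.ops.aten.add.Tensor.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)
    #     def _my_add_kernel(x, y, alpha=1):
    #         ...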

    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key was not cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). del_dispatch is to be used
    # only if you are specifically modifying how get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == torch._C.DispatchKey.Python:
            if not self.python_key_mode_table:
                self._dispatch_cache[key] = key
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = _get_current_dispatch_mode()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
                if type(curr_mode) not in self.python_key_mode_table:
                    # TODO: This path is slow, should generally encourage this
                    # case to not happen
                    return self._op_dk(key, *args, **kwargs)
                # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
                return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)

            self._dispatch_cache[key] = handler
            return handler

        final_key = resolve_key(self, key)

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having if statement for each of them
        if key == torch._C.DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                self._dispatch_cache[key] = handler
                return handler

        # print(self, key, final_key)
        r = self.py_kernels.get(final_key, final_key)
        self._dispatch_cache[key] = r
        return r

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments


# An OpOverloadPacket contains a pointer to a base unresolved operator that doesn't
# correspond to a specific operator overload.
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        self._dir = []

    # It's a no-op since the OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverloadPacket(op='{}.{}')>".format(
            *self._qualified_op_name.split("::")
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}.{}".format(*self._qualified_op_name.split("::"))

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # Ensure that queries for dunder attributes that do not exist on the
        # OpOverloadPacket, but do exist on the underlying self._op object, do not
        # unnecessarily call `_get_operation_overload` (which is an expensive operation).
        # This is done to prevent any potential slowdown. This check can be extended
        # if there are other attributes, like `__name__`, that only exist on self._op
        # and not on the OpOverloadPacket.
        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'.
        try:
            if key.startswith("__"):
                return getattr(self._op, key)
        except AttributeError:
            # For consistency, because it seems weird to
            # throw an attribute error with a message containing
            # an object name different from the one the attribute
            # query was performed on.
            raise AttributeError(
                "'{}' can't have an overload name beginning with '__' and the "
                "underlying op {} has no attribute {} either.".format(
                    str(self), str(self._op), key
                )
            ) from None

        try:
            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
            use_key = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_, op_dk_, tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            overload = OpOverload(self, op_, op_dk_, schema, tags)
            # Cache the overload object
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'".format(
                    str(self), key
                )
            ) from None

    def __iter__(self):
        return iter(self._dir)

    def __call__(self, *args, **kwargs):
        # Overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT.
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
        return self._op(*args, **kwargs or {})

    # TODO: use this to make a __dir__
    def overloads(self):
        return [n if n else "default" for n in self._overload_names]
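
    # Illustrative sketch (not executed): `torch.ops.aten.add` is an OpOverloadPacket;
    # attribute access resolves and caches individual OpOverload objects, and
    # `overloads()` lists the available overload names.
    #
    #     packet = torch.ops.aten.add            # OpOverloadPacket
    #     overload = torch.ops.aten.add.Tensor   # OpOverload named "add.Tensor"
    #     names = packet.overloads()             # e.g. ["Tensor", "Scalar", "out", ...]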


# Resolution of torch.fn is different from torch.ops.aten.fn:
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method.
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT:
# JIT creates a stack of all the overloads, tries to match the
# correct one at runtime, and always calls into the boxed version of the method.
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the tensor class.


# _OpNamespace is a subclass of ModuleType because TorchScript
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules
class _OpNamespace(types.ModuleType):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:

    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch.get_operation`, a function bound from C++, and then in a similar
       fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
       and subsequent accesses will incur no further lookup (the namespace and
       operation will already exist).
    """

    def __init__(self, name):
        super().__init__("torch.ops." + name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        elif op_name == "__origin__":
            raise AttributeError()

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there is more than one.
        namespace_name = self.name
        qualified_op_name = "{}::{}".format(namespace_name, op_name)
        try:
            op, overload_names = torch._C._jit_get_operation(qualified_op_name)
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        # Let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualified_op_name)
        op.__module__ = self.__module__ + "." + namespace_name
        opoverloadpacket = OpOverloadPacket(
            qualified_op_name, op_name, op, overload_names
        )
        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
        # Cache the OpOverloadPacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket


class _PyOpNamespace(_OpNamespace):
    def __init__(self):
        super().__init__("torch.ops")
        self.pyop_namespace = pyop_namespace


class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self.pyops = _PyOpNamespace()
        self._dir = []

    def __getattr__(self, name):
        # Check if the name is a pyop
        if name in self.pyops.pyop_namespace:
            return self.pyops.pyop_namespace[name]

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace

    def __iter__(self):
        return iter(self._dir)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        if sys.executable == "torch_deploy":
            return

        path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)
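
# Illustrative sketch (not executed): loading a custom-op library and calling an
# operator it registered. Both the path and `my_ns::my_op` are hypothetical placeholders.
#
#     torch.ops.load_library("build/libcustom_ops.so")
#     print(torch.ops.loaded_libraries)
#     torch.ops.my_ns.my_op(torch.randn(3))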


# The ops "namespace"
ops = _Ops()